From 8bfe3d58b6c5f3f2b2d007d13206ec5a6c0b9916 Mon Sep 17 00:00:00 2001 From: Pavel Anni Date: Sat, 14 Dec 2024 18:26:51 -0500 Subject: [PATCH] add lab manager --- .gitignore | 4 + cmd/create-key.go | 174 -------- cmd/create-lab.go | 200 --------- cmd/create_key.go | 131 ++++++ cmd/create_lab.go | 140 ++++++ cmd/create_lab_test.go | 91 ++++ cmd/{create-server.go => create_server.go} | 73 ++-- cmd/{create-volume.go => create_volume.go} | 0 cmd/delete.go | 9 +- cmd/{delete-key.go => delete_key.go} | 24 +- cmd/{delete-lab.go => delete_lab.go} | 6 +- cmd/delete_lab_test.go | 117 +++++ cmd/{delete-server.go => delete_server.go} | 0 cmd/{delete-volume.go => delete_volume.go} | 0 cmd/{get-key.go => get_key.go} | 0 cmd/{get-lab.go => get_lab.go} | 5 +- cmd/get_lab_test.go | 214 ++++++++++ cmd/{get-server.go => get_server.go} | 0 cmd/{get-volume.go => get_volume.go} | 0 cmd/init.go | 52 +++ cmd/root.go | 16 +- cmd/sync.go | 2 +- containers/Containerfile.fedora | 67 +++ internal/config/config.go | 17 +- internal/config/constants.go | 13 +- internal/lab/lab.go | 398 ++++++++++++++++++ internal/lab/mock/manager.go | 36 ++ internal/provider/hetzner/lab.go | 239 ----------- internal/provider/hetzner/provider.go | 31 +- internal/provider/hetzner/server.go | 20 + internal/provider/hetzner/sshkey.go | 65 ++- internal/provider/mock/mock_provider.go | 200 +++++++++ internal/provider/mock/mock_provider_test.go | 112 +++++ internal/provider/types.go | 13 +- internal/ssh/keys.go | 159 +++++++ internal/types/errors.go | 19 + internal/types/types.go | 10 +- internal/util/serverchecker/serverchecker.go | 13 +- .../util/serverchecker/serverchecker_test.go | 13 +- 39 files changed, 1933 insertions(+), 750 deletions(-) delete mode 100644 cmd/create-key.go delete mode 100644 cmd/create-lab.go create mode 100644 cmd/create_key.go create mode 100644 cmd/create_lab.go create mode 100644 cmd/create_lab_test.go rename cmd/{create-server.go => create_server.go} (71%) rename cmd/{create-volume.go => create_volume.go} (100%) rename cmd/{delete-key.go => delete_key.go} (60%) rename cmd/{delete-lab.go => delete_lab.go} (79%) create mode 100644 cmd/delete_lab_test.go rename cmd/{delete-server.go => delete_server.go} (100%) rename cmd/{delete-volume.go => delete_volume.go} (100%) rename cmd/{get-key.go => get_key.go} (100%) rename cmd/{get-lab.go => get_lab.go} (94%) create mode 100644 cmd/get_lab_test.go rename cmd/{get-server.go => get_server.go} (100%) rename cmd/{get-volume.go => get_volume.go} (100%) create mode 100644 containers/Containerfile.fedora create mode 100644 internal/lab/lab.go create mode 100644 internal/lab/mock/manager.go delete mode 100644 internal/provider/hetzner/lab.go create mode 100644 internal/provider/mock/mock_provider.go create mode 100644 internal/provider/mock/mock_provider_test.go create mode 100644 internal/ssh/keys.go create mode 100644 internal/types/errors.go diff --git a/.gitignore b/.gitignore index 658d1f3..5f05d73 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,8 @@ go.work.sum !/config/*.example.yaml *.db +# GoReleaser build artifacts dist/ + +# Scratch space +scratch/ diff --git a/cmd/create-key.go b/cmd/create-key.go deleted file mode 100644 index 99bf7b8..0000000 --- a/cmd/create-key.go +++ /dev/null @@ -1,174 +0,0 @@ -package cmd - -import ( - "crypto" - "crypto/ed25519" - "crypto/rand" - "encoding/base64" - "encoding/pem" - "fmt" - "os" - "path/filepath" - "time" - - "github.com/pavelanni/storctl/internal/config" - "github.com/pavelanni/storctl/internal/provider/options" - "github.com/pavelanni/storctl/internal/types" - "github.com/pavelanni/storctl/internal/util/labelutil" - "github.com/pavelanni/storctl/internal/util/timeutil" - "github.com/spf13/cobra" - "golang.org/x/crypto/ssh" -) - -func NewCreateKeyCmd() *cobra.Command { - var labels map[string]string - var ttl string - - cmd := &cobra.Command{ - Use: "key [name]", - - Short: "Create and upload an SSH key pair", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - keyName := args[0] - keyResource := &types.SSHKey{ - TypeMeta: types.TypeMeta{ - APIVersion: "v1", - Kind: "SSHKey", - }, - ObjectMeta: types.ObjectMeta{ - Name: keyName, - Labels: labels, - }, - } - key, err := createKey(keyResource) - if err != nil { - return err - } - fmt.Printf("SSH key created successfully: %s\n", key.ObjectMeta.Name) - return nil - }, - } - - cmd.Flags().StringToStringVar(&labels, "labels", map[string]string{}, "SSH key labels") - cmd.Flags().StringVar(&ttl, "ttl", config.DefaultTTL, "Time to live for the key") - return cmd -} - -func createKey(key *types.SSHKey) (*types.SSHKey, error) { - keyName := key.ObjectMeta.Name - if keyName == "" { - return nil, fmt.Errorf("key name is required") - } - - // Check if key already exists locally - keysDir := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.KeysDir) - localKeyPath := filepath.Join(keysDir, keyName) - if _, err := os.Stat(localKeyPath); err == nil { - return nil, fmt.Errorf("key %s already exists locally at %s", keyName, localKeyPath) - } - - // Check if key already exists on the provider - exists, err := providerSvc.KeyExists(keyName) - if err != nil { - return nil, fmt.Errorf("failed to check if key exists on provider: %w", err) - } - if exists { - return nil, fmt.Errorf("key %s already exists on the provider", keyName) - } - - fmt.Printf("Creating key %s\n", keyName) - labels := key.ObjectMeta.Labels - var ttl string - if key.Spec.TTL == "" { - ttl = config.DefaultTTL - } else { - ttl = key.Spec.TTL - } - duration, err := timeutil.TtlToDuration(ttl) - if err != nil { - return nil, fmt.Errorf("failed to parse ttl: %w", err) - } - labels["delete_after"] = timeutil.FormatDeleteAfter(time.Now().Add(duration)) - labels["owner"] = labelutil.SanitizeValue(cfg.Owner) - pubKeyString := key.Spec.PublicKey - // If public key is not provided, generate a new key pair - if pubKeyString == "" { - // Generate the key pair - pubKey, privKey, err := generateED25519KeyPair(keyName) - if err != nil { - return nil, fmt.Errorf("failed to generate key pair: %w", err) - } - - // Save the keys locally - keysDir := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.KeysDir) - if err := os.MkdirAll(keysDir, 0700); err != nil { - return nil, fmt.Errorf("failed to create keys directory: %w", err) - } - - // Save private key - privKeyPath := filepath.Join(keysDir, keyName) - if err := os.WriteFile(privKeyPath, privKey, 0600); err != nil { - return nil, fmt.Errorf("failed to save private key: %w", err) - } - - // Save public key - pubKeyPath := filepath.Join(keysDir, keyName+".pub") - if err := os.WriteFile(pubKeyPath, pubKey, 0644); err != nil { - return nil, fmt.Errorf("failed to save public key: %w", err) - } - pubKeyString = string(pubKey) - fmt.Printf("SSH key pair created successfully: %s\n", keyName) - key.Spec.PublicKey = pubKeyString - } - - // Upload public key to provider - key, err = providerSvc.CreateSSHKey(options.SSHKeyCreateOpts{ - Name: keyName, - PublicKey: pubKeyString, - Labels: labels, - }) - if err != nil { - return nil, fmt.Errorf("failed to upload public key: %w", err) - } - - fmt.Printf("SSH key uploaded to provider: %s\n", keyName) - return key, nil -} - -// generateED25519KeyPair generates a new ED25519 keypair. -// Returns public key in OpenSSH format and private key in PEM format as byte slices. -func generateED25519KeyPair(comment string) (publicKey, privateKey []byte, err error) { - // Generate the keypair - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, nil, fmt.Errorf("failed to generate ED25519 keypair: %w", err) - } - - // Convert to SSH private key format and encode as PEM - pemBlock, err := ssh.MarshalPrivateKey(crypto.PrivateKey(priv), comment) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal private key: %w", err) - } - - // Encode private key in PEM format - privateKey = pem.EncodeToMemory(pemBlock) - if privateKey == nil { - return nil, nil, fmt.Errorf("failed to encode private key") - } - - // Generate the public key - sshPub, err := ssh.NewPublicKey(pub) - if err != nil { - return nil, nil, fmt.Errorf("failed to create public key: %w", err) - } - - // Format public key in OpenSSH format: "ssh-ed25519 comment" - pubKey := fmt.Sprintf("%s %s", sshPub.Type(), - base64.StdEncoding.EncodeToString(sshPub.Marshal())) - if comment != "" { - pubKey = fmt.Sprintf("%s %s", pubKey, comment) - } - - return []byte(pubKey), privateKey, nil -} diff --git a/cmd/create-lab.go b/cmd/create-lab.go deleted file mode 100644 index b634dab..0000000 --- a/cmd/create-lab.go +++ /dev/null @@ -1,200 +0,0 @@ -package cmd - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "strings" - "time" - - "github.com/pavelanni/storctl/internal/config" - "github.com/pavelanni/storctl/internal/types" - "github.com/pavelanni/storctl/internal/util/labelutil" - "github.com/pavelanni/storctl/internal/util/serverchecker" - "github.com/pavelanni/storctl/internal/util/timeutil" - "github.com/spf13/cobra" - "k8s.io/apimachinery/pkg/util/yaml" -) - -func NewCreateLabCmd() *cobra.Command { - var ( - template string - name string - provider string - location string - ttl string - ) - - cmd := &cobra.Command{ - Use: "lab [name]", - Short: "Create a new lab environment", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - name = args[0] - lab, err := labFromTemplate(template, name, provider, location, ttl) - if err != nil { - return fmt.Errorf("error parsing lab template: %w", err) - } - _, err = createLab(lab) - if err != nil { - return fmt.Errorf("error creating lab: %w", err) - } - return nil - }, - } - - defaultTemplate := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.DefaultTemplateDir, "lab.yaml") - cmd.Flags().StringVar(&template, "template", defaultTemplate, "lab template to use") - cmd.Flags().StringVar(&provider, "provider", config.DefaultProvider, "provider to use") - cmd.Flags().StringVar(&location, "location", config.DefaultLocation, "location to use") - cmd.Flags().StringVar(&ttl, "ttl", config.DefaultTTL, "ttl to use") - - return cmd -} - -func createLab(lab *types.Lab) (*types.Lab, error) { - lab.ObjectMeta.Labels["owner"] = labelutil.SanitizeValue(cfg.Owner) - lab.ObjectMeta.Labels["organization"] = labelutil.SanitizeValue(cfg.Organization) - lab.ObjectMeta.Labels["email"] = labelutil.SanitizeValue(cfg.Email) - lab.ObjectMeta.Labels["lab_name"] = lab.ObjectMeta.Name - ttl := lab.Spec.TTL - if ttl == "" { - ttl = config.DefaultTTL - } - duration, err := timeutil.TtlToDuration(ttl) - if err != nil { - return nil, fmt.Errorf("failed to parse ttl: %w", err) - } - lab.ObjectMeta.Labels["delete_after"] = timeutil.FormatDeleteAfter(time.Now().Add(duration)) - - keyNames := []string{strings.Join([]string{lab.ObjectMeta.Name, "admin"}, "-")} - // Create servers - specServers := lab.Spec.Servers - servers := make([]*types.Server, 0) - for _, serverSpec := range specServers { - s := &types.Server{ - TypeMeta: types.TypeMeta{ - Kind: "Server", - APIVersion: "v1", - }, - ObjectMeta: types.ObjectMeta{ - Name: strings.Join([]string{lab.ObjectMeta.Name, serverSpec.Name}, "-"), - Labels: lab.ObjectMeta.Labels, - }, - Spec: types.ServerSpec{ - Location: lab.Spec.Location, - Provider: lab.Spec.Provider, - ServerType: serverSpec.ServerType, - TTL: ttl, - Image: serverSpec.Image, - SSHKeyNames: keyNames, - }, - } - result, err := createServer(s) - if err != nil { - return nil, err - } - if err := addDNSRecord(result); err != nil { - return nil, err - } - servers = append(servers, result) - } - // Add a DNS record for 'aistor.' using the IP of the control plane server - cpPublicNet := servers[0].Status.PublicNet - aistorServer := &types.Server{ - ObjectMeta: types.ObjectMeta{ - Name: strings.Join([]string{lab.ObjectMeta.Name, "aistor"}, "-"), - Labels: lab.ObjectMeta.Labels, - }, - Status: types.ServerStatus{ - PublicNet: cpPublicNet, - }, - } - if err := addDNSRecord(aistorServer); err != nil { - return nil, err - } - - // Wait for servers to be ready - timeout := 30 * time.Minute - attempts := 20 - results, err := serverchecker.CheckServers(servers, cfg.LogLevel, timeout, attempts) - if err != nil { - return nil, err - } - for _, result := range results { - fmt.Printf("Server %s: Ready: %v\n", result.Server.ObjectMeta.Name, result.Ready) - if !result.Ready { - return nil, fmt.Errorf("server %s not ready", result.Server.ObjectMeta.Name) - } - } - - // Create volumes - volumes := lab.Spec.Volumes - for _, volumeSpec := range volumes { - if !volumeSpec.Automount { // if not specified, default to false - volumeSpec.Automount = config.DefaultVolumeAutomount - } - if volumeSpec.Format == "" { // if not specified, default to xfs - volumeSpec.Format = config.DefaultVolumeFormat - } - v := &types.Volume{ - TypeMeta: types.TypeMeta{ - Kind: "Volume", - APIVersion: "v1", - }, - ObjectMeta: types.ObjectMeta{ - Name: strings.Join([]string{lab.ObjectMeta.Name, volumeSpec.Name}, "-"), - Labels: lab.ObjectMeta.Labels, - }, - Spec: types.VolumeSpec{ - Size: volumeSpec.Size, - ServerName: strings.Join([]string{lab.ObjectMeta.Name, volumeSpec.Server}, "-"), - Automount: volumeSpec.Automount, - Format: volumeSpec.Format, - }, - } - if err := createVolume(v); err != nil { - return nil, err - } - } - return lab, nil -} - -func labFromTemplate(template, name, provider, location, ttl string) (*types.Lab, error) { - data, err := os.ReadFile(template) - if err != nil { - return nil, fmt.Errorf("error reading file: %w", err) - } - - decoder := yaml.NewYAMLOrJSONDecoder(bytes.NewBuffer(data), 4096) - lab := &types.Lab{} - if err := decoder.Decode(lab); err != nil { - return nil, fmt.Errorf("error decoding YAML: %w", err) - } - lab.ObjectMeta.Name = name - lab.Spec.Provider = provider - lab.Spec.Location = location - lab.Spec.TTL = ttl - return lab, nil -} - -func addDNSRecord(server *types.Server) error { - labName, ok := server.ObjectMeta.Labels["lab_name"] - if !ok { - labName = "no-lab" - } - labName = strings.ToLower(labName) - serverName := strings.ToLower(server.Name) - // remove the leading labName with "-" from the serverName - serverName = strings.TrimPrefix(serverName, labName+"-") - err := dnsSvc.AddRecord(cfg.DNS.ZoneID, - strings.Join([]string{serverName, labName}, "."), - "A", - server.Status.PublicNet.IPv4.IP, - false) - if err != nil { - return err - } - return nil -} diff --git a/cmd/create_key.go b/cmd/create_key.go new file mode 100644 index 0000000..169b3e4 --- /dev/null +++ b/cmd/create_key.go @@ -0,0 +1,131 @@ +package cmd + +import ( + "fmt" + "time" + + "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/provider/options" + "github.com/pavelanni/storctl/internal/ssh" + "github.com/pavelanni/storctl/internal/types" + "github.com/pavelanni/storctl/internal/util/labelutil" + "github.com/pavelanni/storctl/internal/util/timeutil" + "github.com/spf13/cobra" +) + +func NewCreateKeyCmd() *cobra.Command { + var labels map[string]string + var ttl string + + cmd := &cobra.Command{ + Use: "key [name]", + + Short: "Create and upload an SSH key pair", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + keyName := args[0] + keyResource := &types.SSHKey{ + TypeMeta: types.TypeMeta{ + APIVersion: "v1", + Kind: "SSHKey", + }, + ObjectMeta: types.ObjectMeta{ + Name: keyName, + Labels: labels, + }, + } + key, err := createKey(keyResource) + if err != nil { + return err + } + fmt.Printf("SSH key created successfully: %s\n", key.ObjectMeta.Name) + return nil + }, + } + + cmd.Flags().StringToStringVar(&labels, "labels", map[string]string{}, "SSH key labels") + cmd.Flags().StringVar(&ttl, "ttl", config.DefaultTTL, "Time to live for the key") + return cmd +} + +func createKey(key *types.SSHKey) (*types.SSHKey, error) { + keyManager := ssh.NewManager(cfg) + keyName := key.ObjectMeta.Name + if keyName == "" { + return nil, fmt.Errorf("key name is required") + } + + // If public key is not provided, generate a new key pair + if key.Spec.PublicKey == "" { + // Check if the key already exists locally + // create the key pair if it doesn't exist + localKeyExists, err := keyManager.LocalKeyExists(keyName) + if err != nil { + return nil, fmt.Errorf("failed to check if key exists locally: %w", err) + } + if !localKeyExists { + fmt.Printf("Creating keypair %s locally\n", keyName) + pubKey, err := keyManager.CreateLocalKeyPair(keyName) + if err != nil { + return nil, fmt.Errorf("failed to create keypair: %w", err) + } + key.Spec.PublicKey = pubKey + } else { + pubKey, err := keyManager.ReadLocalPublicKey(keyName) + if err != nil { + return nil, fmt.Errorf("failed to read local public key: %w", err) + } + key.Spec.PublicKey = pubKey + } + } + + // Check if key already exists on the provider + keyExists, err := providerSvc.CloudKeyExists(keyName) + if err != nil { + return nil, fmt.Errorf("failed to check if key exists on provider: %w", err) + } + if keyExists { + fmt.Printf("SSH key %s already exists on the provider\n", keyName) + cloudKey, err := providerSvc.GetSSHKey(keyName) + if err != nil { + return nil, fmt.Errorf("failed to get key from provider: %w", err) + } + // is it the same key? + if cloudKey.Spec.PublicKey == key.Spec.PublicKey { + return key, nil + } else { + fmt.Printf("SSH key %s already exists on the provider but is different from the local key. Replacing it.\n", keyName) + err := providerSvc.DeleteSSHKey(keyName, true) + if err != nil { + return nil, fmt.Errorf("failed to delete key from provider: %w", err) + } + } + } + + fmt.Printf("Creating SSH key %s on provider\n", keyName) + labels := key.ObjectMeta.Labels + var ttl string + if key.Spec.TTL == "" { + ttl = config.DefaultTTL + } else { + ttl = key.Spec.TTL + } + duration, err := timeutil.TtlToDuration(ttl) + if err != nil { + return nil, fmt.Errorf("failed to parse ttl: %w", err) + } + labels["delete_after"] = timeutil.FormatDeleteAfter(time.Now().Add(duration)) + labels["owner"] = labelutil.SanitizeValue(cfg.Owner) + // Upload public key to provider + key, err = providerSvc.CreateSSHKey(options.SSHKeyCreateOpts{ + Name: keyName, + PublicKey: key.Spec.PublicKey, + Labels: labels, + }) + if err != nil { + return nil, fmt.Errorf("failed to upload public key: %w", err) + } + + fmt.Printf("SSH key uploaded to provider: %s\n", keyName) + return key, nil +} diff --git a/cmd/create_lab.go b/cmd/create_lab.go new file mode 100644 index 0000000..47bc848 --- /dev/null +++ b/cmd/create_lab.go @@ -0,0 +1,140 @@ +package cmd + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/lab" + "github.com/pavelanni/storctl/internal/types" + "github.com/pavelanni/storctl/internal/util/labelutil" + "github.com/pavelanni/storctl/internal/util/timeutil" + "github.com/spf13/cobra" + "k8s.io/apimachinery/pkg/util/yaml" +) + +func NewCreateLabCmd() *cobra.Command { + var ( + template string + name string + provider string + location string + ttl string + ) + + cmd := &cobra.Command{ + Use: "lab [name]", + Short: "Create a new lab environment", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name = args[0] + lab, err := labFromTemplate(template, name, provider, location, ttl) + if err != nil { + return fmt.Errorf("error parsing lab template: %w", err) + } + _, err = createLab(lab) + if err != nil { + return fmt.Errorf("error creating lab: %w", err) + } + return nil + }, + } + + defaultTemplate := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.DefaultTemplateDir, "lab.yaml") + cmd.Flags().StringVar(&template, "template", defaultTemplate, "lab template to use") + cmd.Flags().StringVar(&provider, "provider", config.DefaultProvider, "provider to use") + cmd.Flags().StringVar(&location, "location", config.DefaultLocation, "location to use") + cmd.Flags().StringVar(&ttl, "ttl", config.DefaultTTL, "ttl to use") + + return cmd +} + +func createLab(newLab *types.Lab) (*types.Lab, error) { + newLab.ObjectMeta.Labels["owner"] = labelutil.SanitizeValue(cfg.Owner) + newLab.ObjectMeta.Labels["organization"] = labelutil.SanitizeValue(cfg.Organization) + newLab.ObjectMeta.Labels["email"] = labelutil.SanitizeValue(cfg.Email) + newLab.ObjectMeta.Labels["lab_name"] = newLab.ObjectMeta.Name + ttl := newLab.Spec.TTL + if ttl == "" { + ttl = config.DefaultTTL + } + duration, err := timeutil.TtlToDuration(ttl) + if err != nil { + return nil, fmt.Errorf("failed to parse ttl: %w", err) + } + newLab.ObjectMeta.Labels["delete_after"] = timeutil.FormatDeleteAfter(time.Now().Add(duration)) + + labManager, err := lab.NewManager(providerSvc, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create lab manager: %w", err) + } + if err := labManager.Create(newLab); err != nil { + return nil, err + } + if err := addDNSRecords(newLab); err != nil { + return nil, err + } + return newLab, nil +} + +func labFromTemplate(template, name, provider, location, ttl string) (*types.Lab, error) { + // Check if the template file exists + if _, err := os.Stat(template); os.IsNotExist(err) { + // Check if it exists in the default template directory + tmpl := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.DefaultTemplateDir, template) + if _, err := os.Stat(tmpl); os.IsNotExist(err) { + return nil, fmt.Errorf("template file does not exist: %s", tmpl) + } + template = tmpl + } + data, err := os.ReadFile(template) + if err != nil { + return nil, fmt.Errorf("error reading file: %w", err) + } + + decoder := yaml.NewYAMLOrJSONDecoder(bytes.NewBuffer(data), 4096) + lab := &types.Lab{} + if err := decoder.Decode(lab); err != nil { + return nil, fmt.Errorf("error decoding YAML: %w", err) + } + lab.ObjectMeta.Name = name + lab.Spec.Provider = provider + lab.Spec.Location = location + lab.Spec.TTL = ttl + return lab, nil +} + +func addDNSRecords(lab *types.Lab) error { + labName, ok := lab.ObjectMeta.Labels["lab_name"] + if !ok { + labName = "no-lab" + } + labName = strings.ToLower(labName) + for _, server := range lab.Status.Servers { + serverName := strings.ToLower(server.Name) + // remove the leading labName with "-" from the serverName + serverName = strings.TrimPrefix(serverName, labName+"-") + err := dnsSvc.AddRecord(cfg.DNS.ZoneID, + strings.Join([]string{serverName, labName}, "."), + "A", + server.Status.PublicNet.IPv4.IP, + false) + if err != nil { + return err + } + } + // Add a DNS record for 'aistor.' using the IP of the control plane server + cpPublicNet := lab.Status.Servers[0].Status.PublicNet + if err := dnsSvc.AddRecord(cfg.DNS.ZoneID, + strings.Join([]string{labName, "aistor"}, "."), + "A", + cpPublicNet.IPv4.IP, + false); err != nil { + return err + } + return nil +} diff --git a/cmd/create_lab_test.go b/cmd/create_lab_test.go new file mode 100644 index 0000000..4cf5985 --- /dev/null +++ b/cmd/create_lab_test.go @@ -0,0 +1,91 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/lab/mock" + "github.com/pavelanni/storctl/internal/types" +) + +func TestCreateLabCmd(t *testing.T) { + tests := []struct { + name string + args []string + mockSetup func(*mock.Manager) + wantErr bool + errContains string + }{ + { + name: "successful lab creation", + args: []string{"test-lab", "--template", "lab.yaml"}, + mockSetup: func(m *mock.Manager) { + m.CreateFunc = func(lab *types.Lab) error { + if lab.ObjectMeta.Name != "test-lab" { + t.Errorf("expected lab name 'test-lab', got '%s'", lab.ObjectMeta.Name) + } + return nil + } + }, + wantErr: false, + }, + { + name: "provider error", + args: []string{"test-lab", "--template", "lab.yaml"}, + mockSetup: func(m *mock.Manager) { + m.CreateFunc = func(lab *types.Lab) error { + return types.NewError("provider error", "failed to create lab") + } + }, + wantErr: true, + errContains: "failed to create lab", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a new mock provider for each test + mockManager := &mock.Manager{} + if tt.mockSetup != nil { + tt.mockSetup(mockManager) + } + + // Store the current provider and restore it after the test + originalManager := labManager + labManager = mockManager + defer func() { labManager = originalManager }() + + // Store the current config and restore it after the test + originalCfg := cfg + cfg = &config.Config{ + Owner: "test-owner", + Organization: "test-organization", + Email: "test-email", + } + defer func() { cfg = originalCfg }() + + // Create and execute the command + cmd := NewCreateLabCmd() + cmd.SetArgs(tt.args) + err := cmd.Execute() + + // Check error expectations + if (err != nil) != tt.wantErr { + t.Errorf("CreateLabCmd() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err != nil && tt.errContains != "" { + if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error to contain '%s', got '%s'", tt.errContains, err.Error()) + } + } + }) + } +} + +// Helper function to check if a string contains another string +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} diff --git a/cmd/create-server.go b/cmd/create_server.go similarity index 71% rename from cmd/create-server.go rename to cmd/create_server.go index f65af67..7b241a4 100644 --- a/cmd/create-server.go +++ b/cmd/create_server.go @@ -6,6 +6,7 @@ import ( "github.com/pavelanni/storctl/internal/config" "github.com/pavelanni/storctl/internal/provider/options" + "github.com/pavelanni/storctl/internal/ssh" "github.com/pavelanni/storctl/internal/types" "github.com/pavelanni/storctl/internal/util/labelutil" "github.com/pavelanni/storctl/internal/util/timeutil" @@ -54,20 +55,24 @@ func NewCreateServerCmd() *cobra.Command { }, } - cmd.Flags().StringSliceVar(&sshKeyNames, "ssh-keys", []string{}, "SSH key names to use (required)") + cmd.Flags().StringSliceVar(&sshKeyNames, "ssh-keys", []string{}, "SSH key names to use; if not provided, the admin key will be created") cmd.Flags().StringVar(&serverType, "type", config.DefaultServerType, "Server type") cmd.Flags().StringVar(&image, "image", config.DefaultImage, "Server image") cmd.Flags().StringVar(&location, "location", config.DefaultLocation, "Server location") cmd.Flags().StringVar(&ttl, "ttl", config.DefaultTTL, "Server TTL") cmd.Flags().StringToStringVar(&labels, "labels", map[string]string{}, "Server labels") - if err := cmd.MarkFlagRequired("ssh-keys"); err != nil { - panic(err) - } return cmd } func createServer(server *types.Server) (*types.Server, error) { + sshManager := ssh.NewManager(cfg) + // no ssh keys provided, use the admin key + if len(server.Spec.SSHKeyNames) == 0 { + serverKeyName := server.ObjectMeta.Name + "-admin" + fmt.Printf("No SSH keys provided, using default: %s\n", serverKeyName) + server.Spec.SSHKeyNames = []string{serverKeyName} + } // Access fields using map syntax fmt.Printf("Creating server %s with type %s, image %s, location %s, ssh keys %v\n", server.ObjectMeta.Name, @@ -76,11 +81,6 @@ func createServer(server *types.Server) (*types.Server, error) { server.Spec.Location, server.Spec.SSHKeyNames) - if len(server.Spec.SSHKeyNames) == 0 { - serverKeyName := server.ObjectMeta.Name + "-admin" - fmt.Printf("No SSH keys provided, using default: %s\n", serverKeyName) - server.Spec.SSHKeyNames = []string{serverKeyName} - } ttl := config.DefaultTTL if server.Spec.TTL != "" { ttl = server.Spec.TTL @@ -94,47 +94,30 @@ func createServer(server *types.Server) (*types.Server, error) { labels["delete_after"] = timeutil.FormatDeleteAfter(time.Now().Add(duration)) labels["owner"] = labelutil.SanitizeValue(cfg.Owner) - sshKeys := make([]*types.SSHKey, 0) + // create the ssh keys locally for _, sshKeyName := range server.Spec.SSHKeyNames { - keyExists, err := providerSvc.KeyExists(sshKeyName) + _, err := sshManager.CreateLocalKeyPair(sshKeyName) if err != nil { - return nil, err - } - if !keyExists { - fmt.Printf("Creating new SSH key: %s\n", sshKeyName) - newKey, err := createKey(&types.SSHKey{ - TypeMeta: types.TypeMeta{ - Kind: "SSHKey", - }, - ObjectMeta: types.ObjectMeta{ - Name: sshKeyName, - Labels: labels, - }, - }) - if err != nil { - return nil, err - } - sshKeys = append(sshKeys, newKey) - } else { - providerKey, err := providerSvc.GetSSHKey(sshKeyName) - if err != nil { - return nil, err - } - sshKeys = append(sshKeys, providerKey) + return nil, fmt.Errorf("failed to create local ssh key: %w", err) } } - - cloudInitUserData := fmt.Sprintf(config.DefaultCloudInitUserData, sshKeys[0].Spec.PublicKey) - result, err := providerSvc.CreateServer(options.ServerCreateOpts{ - Name: server.ObjectMeta.Name, - Type: server.Spec.ServerType, - Image: server.Spec.Image, - Location: server.Spec.Location, - Provider: server.Spec.Provider, - SSHKeys: sshKeys, - Labels: labels, - UserData: cloudInitUserData, + sshKeys, err := providerSvc.KeyNamesToSSHKeys(server.Spec.SSHKeyNames, options.SSHKeyCreateOpts{ + Labels: labels, }) + if err != nil { + return nil, fmt.Errorf("failed to upload ssh keys to the cloud: %w", err) + } + + // create the cloud init user data with the admin key + opts, err := providerSvc.ServerToCreateOpts(server) + if err != nil { + return nil, err + } + opts.SSHKeys = sshKeys + result, err := providerSvc.CreateServer(opts) + if err != nil { + return nil, err + } return result, err } diff --git a/cmd/create-volume.go b/cmd/create_volume.go similarity index 100% rename from cmd/create-volume.go rename to cmd/create_volume.go diff --git a/cmd/delete.go b/cmd/delete.go index db9eaae..4d2d5b2 100644 --- a/cmd/delete.go +++ b/cmd/delete.go @@ -6,6 +6,7 @@ import ( "io" "os" + "github.com/pavelanni/storctl/internal/lab" "github.com/pavelanni/storctl/internal/types" "github.com/spf13/cobra" "k8s.io/apimachinery/pkg/util/yaml" @@ -95,8 +96,12 @@ func processDeleteResource(resource *types.Resource, assumeYes, skipTimeCheck bo } return nil case "Lab": - if status := providerSvc.DeleteLab(resourceName, skipTimeCheck); status.Error != nil { - return fmt.Errorf("failed to delete lab: %w", status.Error) + labManager, err := lab.NewManager(providerSvc, cfg) + if err != nil { + return fmt.Errorf("failed to create lab manager: %w", err) + } + if err := labManager.Delete(resourceName, skipTimeCheck); err != nil { + return fmt.Errorf("failed to delete lab: %w", err) } return nil default: diff --git a/cmd/delete-key.go b/cmd/delete_key.go similarity index 60% rename from cmd/delete-key.go rename to cmd/delete_key.go index ae3b68d..4430c7b 100644 --- a/cmd/delete-key.go +++ b/cmd/delete_key.go @@ -2,11 +2,9 @@ package cmd import ( "fmt" - "os" - "path/filepath" "time" - "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/ssh" "github.com/spf13/cobra" ) @@ -34,22 +32,12 @@ func NewDeleteSSHKeyCmd() *cobra.Command { fmt.Printf("Key %s is not ready for deletion until %s UTC\n", keyName, status.DeleteAfter.Format("2006-01-02 15:04:05")) return nil } - privateKeyPath := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.KeysDir, keyName) - publicKeyPath := privateKeyPath + ".pub" - // Delete the key from the keys directory - // check if the file exists - if _, err := os.Stat(privateKeyPath); err == nil { - if err := os.Remove(privateKeyPath); err != nil { - return fmt.Errorf("failed to delete private key from the keys directory: %w", err) - } + // delete the local key pair + keyManager := ssh.NewManager(cfg) + err := keyManager.DeleteLocalKeyPair(keyName) + if err != nil { + return fmt.Errorf("failed to delete local key pair: %w", err) } - // Delete the public key from the keys directory - if _, err := os.Stat(publicKeyPath); err == nil { - if err := os.Remove(publicKeyPath); err != nil { - return fmt.Errorf("failed to delete public key from the keys directory: %w", err) - } - } - fmt.Printf("Successfully deleted key %s\n", keyName) return nil }, diff --git a/cmd/delete-lab.go b/cmd/delete_lab.go similarity index 79% rename from cmd/delete-lab.go rename to cmd/delete_lab.go index 83a3983..5eb9769 100644 --- a/cmd/delete-lab.go +++ b/cmd/delete_lab.go @@ -21,9 +21,9 @@ func NewDeleteLabCmd() *cobra.Command { return nil } - // Delete the lab using cloud provider - if status := providerSvc.DeleteLab(labName, skipTimeCheck); status.Error != nil { - return fmt.Errorf("failed to delete lab: %w", status.Error) + // Delete the lab using lab manager + if err := labManager.Delete(labName, skipTimeCheck); err != nil { + return fmt.Errorf("failed to delete lab: %w", err) } fmt.Printf("Successfully deleted lab %s\n", labName) diff --git a/cmd/delete_lab_test.go b/cmd/delete_lab_test.go new file mode 100644 index 0000000..16535c6 --- /dev/null +++ b/cmd/delete_lab_test.go @@ -0,0 +1,117 @@ +package cmd + +import ( + "testing" + + "github.com/pavelanni/storctl/internal/lab/mock" + "github.com/pavelanni/storctl/internal/types" +) + +func TestDeleteLabCmd(t *testing.T) { + tests := []struct { + name string + args []string + flags []string + mockSetup func(*mock.Manager) + wantErr bool + errContains string + }{ + { + name: "successful lab deletion", + args: []string{"test-lab"}, + flags: []string{"--yes"}, // Skip confirmation + mockSetup: func(m *mock.Manager) { + m.DeleteFunc = func(name string, force bool) error { + if name != "test-lab" { + t.Errorf("expected lab name 'test-lab', got '%s'", name) + } + return nil + } + }, + wantErr: false, + }, + { + name: "force deletion", + args: []string{"test-lab"}, + flags: []string{"--yes", "--force"}, + mockSetup: func(m *mock.Manager) { + m.DeleteFunc = func(name string, force bool) error { + if !force { + t.Error("expected force flag to be true") + } + return nil + } + }, + wantErr: false, + }, + { + name: "lab not found", + args: []string{"nonexistent-lab"}, + flags: []string{"--yes"}, + mockSetup: func(m *mock.Manager) { + m.DeleteFunc = func(name string, force bool) error { + return types.NewError("NOT_FOUND", "lab not found") + } + }, + wantErr: true, + errContains: "lab not found", + }, + { + name: "missing lab name", + args: []string{}, + flags: []string{"--yes"}, + mockSetup: func(m *mock.Manager) { + m.DeleteFunc = func(name string, force bool) error { + t.Error("DeleteLab should not be called when lab name is missing") + return nil + } + }, + wantErr: true, + errContains: "requires exactly 1 arg", + }, + { + name: "provider error", + args: []string{"test-lab"}, + flags: []string{"--yes"}, + mockSetup: func(m *mock.Manager) { + m.DeleteFunc = func(name string, force bool) error { + return types.NewError("PROVIDER_ERROR", "failed to delete lab") + } + }, + wantErr: true, + errContains: "failed to delete lab", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a new mock provider for each test + mockManager := &mock.Manager{} + if tt.mockSetup != nil { + tt.mockSetup(mockManager) + } + + // Store the current provider and restore it after the test + originalManager := labManager + labManager = mockManager + defer func() { labManager = originalManager }() + + // Create and execute the command + cmd := NewDeleteLabCmd() + cmd.SetArgs(append(tt.args, tt.flags...)) + err := cmd.Execute() + + // Check error expectations + if (err != nil) != tt.wantErr { + t.Errorf("DeleteLabCmd() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err != nil && tt.errContains != "" { + if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error to contain '%s', got '%s'", tt.errContains, err.Error()) + } + } + }) + } +} diff --git a/cmd/delete-server.go b/cmd/delete_server.go similarity index 100% rename from cmd/delete-server.go rename to cmd/delete_server.go diff --git a/cmd/delete-volume.go b/cmd/delete_volume.go similarity index 100% rename from cmd/delete-volume.go rename to cmd/delete_volume.go diff --git a/cmd/get-key.go b/cmd/get_key.go similarity index 100% rename from cmd/get-key.go rename to cmd/get_key.go diff --git a/cmd/get-lab.go b/cmd/get_lab.go similarity index 94% rename from cmd/get-lab.go rename to cmd/get_lab.go index 8515df1..b95e938 100644 --- a/cmd/get-lab.go +++ b/cmd/get_lab.go @@ -6,7 +6,6 @@ import ( "text/tabwriter" "time" - "github.com/pavelanni/storctl/internal/provider/options" "github.com/pavelanni/storctl/internal/util/output" "github.com/pavelanni/storctl/internal/util/timeutil" "github.com/spf13/cobra" @@ -30,7 +29,7 @@ func NewGetLabCmd() *cobra.Command { } func listLabs() error { - labs, err := providerSvc.ListLabs(options.LabListOpts{}) + labs, err := labManager.List() if err != nil { return err } @@ -74,7 +73,7 @@ func listLabs() error { func getLab(labName string) error { fmt.Printf("Getting details for lab: %s\n", labName) - lab, err := providerSvc.GetLab(labName) + lab, err := labManager.Get(labName) if err != nil { return err } diff --git a/cmd/get_lab_test.go b/cmd/get_lab_test.go new file mode 100644 index 0000000..71431c2 --- /dev/null +++ b/cmd/get_lab_test.go @@ -0,0 +1,214 @@ +package cmd + +import ( + "testing" + "time" + + "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/lab/mock" + "github.com/pavelanni/storctl/internal/types" +) + +func TestGetLabCmd(t *testing.T) { + // Create some test data + testTime := time.Now() + testLabs := []*types.Lab{ + { + ObjectMeta: types.ObjectMeta{ + Name: "test-lab-1", + Labels: map[string]string{ + "environment": "staging", + }, + }, + Status: types.LabStatus{ + Created: testTime, + State: "running", + Servers: []*types.Server{ + { + ObjectMeta: types.ObjectMeta{ + Name: "server-1", + }, + Spec: types.ServerSpec{ + ServerType: "cx11", + }, + Status: types.ServerStatus{ + Cores: 2, + Memory: 4, + Disk: 50, + }, + }, + }, + Volumes: []*types.Volume{ + { + ObjectMeta: types.ObjectMeta{ + Name: "volume-1", + }, + Spec: types.VolumeSpec{ + Size: 100, + }, + }, + }, + }, + }, + { + ObjectMeta: types.ObjectMeta{ + Name: "test-lab-2", + Labels: map[string]string{ + "environment": "production", + }, + }, + Status: types.LabStatus{ + Created: testTime, + State: "stopped", + }, + }, + } + + // Save original config and restore after tests + originalCfg := cfg + defer func() { cfg = originalCfg }() + + // Initialize config for tests + cfg = &config.Config{ + OutputFormat: "table", + } + + tests := []struct { + name string + args []string + outputFormat string + mockSetup func(*mock.Manager) + wantErr bool + errContains string + }{ + { + name: "list all labs", + args: []string{}, + mockSetup: func(m *mock.Manager) { + m.ListFunc = func() ([]*types.Lab, error) { + return testLabs, nil + } + }, + wantErr: false, + }, + { + name: "get specific lab", + args: []string{"test-lab-1"}, + outputFormat: "table", + mockSetup: func(m *mock.Manager) { + m.GetFunc = func(name string) (*types.Lab, error) { + if name == "test-lab-1" { + return testLabs[0], nil + } + return nil, types.NewError("NOT_FOUND", "lab not found") + } + }, + wantErr: false, + }, + { + name: "get lab in JSON format", + args: []string{"test-lab-1"}, + outputFormat: "json", + mockSetup: func(m *mock.Manager) { + m.GetFunc = func(name string) (*types.Lab, error) { + return testLabs[0], nil + } + }, + wantErr: false, + }, + { + name: "get lab in YAML format", + args: []string{"test-lab-1"}, + outputFormat: "yaml", + mockSetup: func(m *mock.Manager) { + m.GetFunc = func(name string) (*types.Lab, error) { + return testLabs[0], nil + } + }, + wantErr: false, + }, + { + name: "lab not found", + args: []string{"nonexistent-lab"}, + mockSetup: func(m *mock.Manager) { + m.GetFunc = func(name string) (*types.Lab, error) { + return nil, types.NewError("NOT_FOUND", "lab not found") + } + }, + wantErr: true, + errContains: "lab not found", + }, + { + name: "provider error during list", + args: []string{}, + mockSetup: func(m *mock.Manager) { + m.ListFunc = func() ([]*types.Lab, error) { + return nil, types.NewError("PROVIDER_ERROR", "failed to list labs") + } + }, + wantErr: true, + errContains: "failed to list labs", + }, + { + name: "get lab from cloud", + args: []string{"test-lab-1", "--from-cloud"}, + mockSetup: func(m *mock.Manager) { + m.GetFromCloudFunc = func(name string) (*types.Lab, error) { + if name == "test-lab-1" { + return testLabs[0], nil + } + return nil, types.NewError("NOT_FOUND", "lab not found in cloud") + } + }, + wantErr: false, + }, + { + name: "get lab from cloud - not found", + args: []string{"nonexistent-lab", "--from-cloud"}, + mockSetup: func(m *mock.Manager) { + m.GetFromCloudFunc = func(name string) (*types.Lab, error) { + return nil, types.NewError("NOT_FOUND", "lab not found in cloud") + } + }, + wantErr: true, + errContains: "lab not found in cloud", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a new mock lab manager for each test + mockManager := &mock.Manager{} + if tt.mockSetup != nil { + tt.mockSetup(mockManager) + } + + // Set output format for this test + cfg.OutputFormat = tt.outputFormat + + // Store the current lab manager and restore it after the test + // TODO: fix it later + //originalManager := lab.DefaultManager + //lab.DefaultManager = &lab.ManagerSvc{ + // Provider: mockManager, + //} + //defer func() { lab.DefaultManager = originalManager }() + + // Create and execute the command + cmd := NewGetLabCmd() + cmd.SetArgs(tt.args) + err := cmd.Execute() + + if (err != nil) != tt.wantErr { + t.Errorf("GetLabCmd() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err != nil && tt.errContains != "" { + if !contains(err.Error(), tt.errContains) { + t.Errorf("expected error to contain '%s', got '%s'", tt.errContains, err.Error()) + } + } + }) + } +} diff --git a/cmd/get-server.go b/cmd/get_server.go similarity index 100% rename from cmd/get-server.go rename to cmd/get_server.go diff --git a/cmd/get-volume.go b/cmd/get_volume.go similarity index 100% rename from cmd/get-volume.go rename to cmd/get_volume.go diff --git a/cmd/init.go b/cmd/init.go index d5bbcc0..58d73c9 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -22,6 +22,12 @@ func NewInitCmd() *cobra.Command { if err := createTemplates(); err != nil { return fmt.Errorf("error creating templates: %w", err) } + if err := createDefaultKeysDir(); err != nil { + return fmt.Errorf("error creating default keys directory: %w", err) + } + if err := createDefaultLabStorage(); err != nil { + return fmt.Errorf("error creating default lab storage: %w", err) + } return nil }, } @@ -108,3 +114,49 @@ func createTemplates() error { fmt.Printf("Lab template file created at %s\n", labTemplateFile) return nil } + +func createDefaultKeysDir() error { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("error getting home directory: %w", err) + } + configDir := filepath.Join(home, config.DefaultConfigDir) + if _, err := os.Stat(configDir); os.IsNotExist(err) { + err = os.MkdirAll(configDir, 0755) + if err != nil { + return fmt.Errorf("error creating config directory: %w", err) + } + } + defaultKeysDir := filepath.Join(configDir, config.DefaultKeysDir) + if _, err := os.Stat(defaultKeysDir); os.IsNotExist(err) { + err = os.MkdirAll(defaultKeysDir, 0700) + if err != nil { + return fmt.Errorf("error creating default keys directory: %w", err) + } + } + fmt.Printf("Default keys directory created at %s\n", defaultKeysDir) + return nil +} + +func createDefaultLabStorage() error { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("error getting home directory: %w", err) + } + configDir := filepath.Join(home, config.DefaultConfigDir) + if _, err := os.Stat(configDir); os.IsNotExist(err) { + err = os.MkdirAll(configDir, 0755) + if err != nil { + return fmt.Errorf("error creating config directory: %w", err) + } + } + labStorageFile := filepath.Join(configDir, config.DefaultLabStorageFile) + if _, err := os.Stat(labStorageFile); os.IsNotExist(err) { + err = os.WriteFile(labStorageFile, []byte(""), 0600) + if err != nil { + return fmt.Errorf("error writing default lab storage file: %w", err) + } + } + fmt.Printf("Default lab storage file created at %s\n", labStorageFile) + return nil +} diff --git a/cmd/root.go b/cmd/root.go index 2f0b41b..1acc154 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,6 +8,7 @@ import ( "github.com/pavelanni/storctl/internal/config" "github.com/pavelanni/storctl/internal/dns" + "github.com/pavelanni/storctl/internal/lab" "github.com/pavelanni/storctl/internal/provider" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -18,6 +19,7 @@ var ( cfg *config.Config providerSvc provider.CloudProvider dnsSvc *dns.CloudflareDNSProvider + labManager lab.Manager logLevel string ) @@ -33,7 +35,7 @@ func NewRootCmd() *cobra.Command { initConfig() initProvider() initDNS() - + initLabManager() return nil }, } @@ -84,6 +86,9 @@ func initConfig() { viper.AutomaticEnv() viper.SetEnvPrefix(strings.ToUpper(config.ToolName)) + viper.SetDefault("storage.path", filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, "labs.db")) + viper.SetDefault("storage.bucket", "labs") + if err := viper.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { fmt.Fprintf(os.Stderr, "Error reading config: %v\n", err) @@ -125,6 +130,15 @@ func initDNS() { } } +func initLabManager() { + var err error + labManager, err = lab.NewManager(providerSvc, cfg) + if err != nil { + fmt.Fprintf(os.Stderr, "Error initializing lab manager: %v\n", err) + os.Exit(1) + } +} + func Execute() error { return NewRootCmd().Execute() } diff --git a/cmd/sync.go b/cmd/sync.go index e4e1b62..07b0c66 100644 --- a/cmd/sync.go +++ b/cmd/sync.go @@ -12,7 +12,7 @@ func NewSyncCmd() *cobra.Command { Short: "Sync labs", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - if err := providerSvc.SyncLabs(); err != nil { + if err := labManager.SyncLabs(); err != nil { return fmt.Errorf("error syncing labs: %w", err) } return nil diff --git a/containers/Containerfile.fedora b/containers/Containerfile.fedora new file mode 100644 index 0000000..d4f5fcf --- /dev/null +++ b/containers/Containerfile.fedora @@ -0,0 +1,67 @@ +# Start with Fedora base image +FROM registry.fedoraproject.org/fedora:41-x86_64 + +# Set labels for better container management +LABEL maintainer="Pavel Anni " +LABEL description="AIStor CLI with storctl, Ansible, and Hetzner Cloud CLI" + +# Install basic tools and dependencies +RUN dnf update -y \ + && dnf install -y dnf-plugins-core \ + && dnf upgrade -y fedora-gpg-keys \ + && dnf install -y \ + librepo \ + libxcrypt-compat \ + git \ + curl \ + wget \ + unzip \ + python3 \ + python3-pip \ + jq \ + vim \ + which \ + tar \ + openssl \ + && dnf clean all + +# Install kubectl +RUN curl -LO "https://dl.k8s.io/release/$(curl -L -s https://dl.k8s.io/release/stable.txt)/bin/linux/amd64/kubectl" && \ + chmod +x kubectl && \ + mv kubectl /usr/local/bin/ + +# Install Krew +RUN (cd /tmp && curl -fsSL https://github.com/kubernetes-sigs/krew/releases/latest/download/krew-linux_amd64.tar.gz | tar xz) && \ + /tmp/krew-linux_amd64 install krew && \ + echo "export PATH=\$PATH:~/.krew/bin" >> /root/.bashrc + + +# Install Helm +RUN curl -fsSL https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash + +# Install Ansible and Kubernetes Python library +RUN pip3 install --no-cache-dir ansible kubernetes + +# Install Hetzner Cloud CLI +RUN wget -O /tmp/hcloud.tar.gz https://github.com/hetznercloud/cli/releases/latest/download/hcloud-linux-amd64.tar.gz && \ + tar -xzf /tmp/hcloud.tar.gz -C /tmp && \ + mv /tmp/hcloud /usr/local/bin/ && \ + chmod +x /usr/local/bin/hcloud && \ + rm -f /tmp/hcloud.tar.gz + +# Create working directory +WORKDIR /workspace + +# Set environment variables +ENV ANSIBLE_HOST_KEY_CHECKING=False + +# Install DirectPV plugin for kubectl +ENV PATH="${PATH}:/root/.krew/bin" +RUN kubectl krew install directpv + +# Install storctl +COPY ./dist/storctl_linux_amd64_v1/storctl /usr/local/bin/storctl +RUN chmod +x /usr/local/bin/storctl + +# Default command +CMD ["/bin/bash"] diff --git a/internal/config/config.go b/internal/config/config.go index b9d0c6f..a1d9e24 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,8 @@ package config import ( "fmt" + "os" + "path/filepath" "github.com/spf13/viper" ) @@ -9,11 +11,17 @@ import ( type Config struct { Provider ProviderConfig `mapstructure:"provider"` DNS DNSConfig `mapstructure:"dns"` + Storage StorageConfig `mapstructure:"storage"` Email string `mapstructure:"email" yaml:"email"` Organization string `mapstructure:"organization" yaml:"organization"` Owner string `mapstructure:"owner" yaml:"owner"` - OutputFormat string - LogLevel string + OutputFormat string `mapstructure:"output_format" yaml:"output_format"` + LogLevel string `mapstructure:"log_level" yaml:"log_level"` +} + +type StorageConfig struct { + Path string `mapstructure:"path" yaml:"path"` + Bucket string `mapstructure:"bucket" yaml:"bucket"` } type ProviderConfig struct { @@ -83,7 +91,6 @@ func LoadConfig(configPath string) (*Config, error) { } func setDefaults(v *viper.Viper) { - // Add any default values here - // Example: - // v.SetDefault("some.default.value", "default") + v.SetDefault("storage.path", filepath.Join(os.Getenv("HOME"), DefaultConfigDir, DefaultLabStorageFile)) + v.SetDefault("storage.bucket", DefaultLabBucket) } diff --git a/internal/config/constants.go b/internal/config/constants.go index a244f23..d1ce239 100644 --- a/internal/config/constants.go +++ b/internal/config/constants.go @@ -22,14 +22,20 @@ const ( // DefaultTemplateDir is the default directory for lab templates DefaultTemplateDir = "templates" - // KeysDir is the subdirectory name for storing SSH keys - KeysDir = "keys" + // DefaultKeysDir is the default directory for storing SSH keys + DefaultKeysDir = "keys" // ConfigFileName is the name of the configuration file ConfigFileName = "config.yaml" // DefaultAdminUser is the default admin user DefaultAdminUser = "ansible" + + // DefaultLabBucket is the default bucket for storing labs + DefaultLabBucket = "labs" + + // DefaultLabStorageFile is the default file for storing labs + DefaultLabStorageFile = "labs.db" ) // Provider related constants @@ -43,6 +49,9 @@ const ( // DefaultLocation is the default location DefaultLocation = "nbg1" + // DefaultAdminKeyName is the default SSH key name + DefaultAdminKeyName = "aistor-admin" + // DefaultImage is the default image DefaultImage = "ubuntu-24.04" diff --git a/internal/lab/lab.go b/internal/lab/lab.go new file mode 100644 index 0000000..671c5f2 --- /dev/null +++ b/internal/lab/lab.go @@ -0,0 +1,398 @@ +package lab + +import ( + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/logger" + "github.com/pavelanni/storctl/internal/provider" + "github.com/pavelanni/storctl/internal/provider/options" + "github.com/pavelanni/storctl/internal/ssh" + "github.com/pavelanni/storctl/internal/types" + "github.com/pavelanni/storctl/internal/util/serverchecker" + "go.etcd.io/bbolt" +) + +type Manager interface { + Create(lab *types.Lab) error + Get(labName string) (*types.Lab, error) + List() ([]*types.Lab, error) + Delete(labName string, force bool) error + SyncLabs() error +} + +type ManagerSvc struct { + provider provider.CloudProvider + sshManager *ssh.Manager + storage *Storage + logger *slog.Logger +} + +type Storage struct { + db *bbolt.DB + labBucket []byte +} + +var DefaultManager *ManagerSvc + +var _ Manager = (*ManagerSvc)(nil) + +func NewBboltDB(path string) (*bbolt.DB, error) { + db, err := bbolt.Open(path, 0600, nil) + if err != nil { + return nil, err + } + return db, nil +} + +func NewLabStorage(cfg *config.Config) (*Storage, error) { + db, err := NewBboltDB(cfg.Storage.Path) + if err != nil { + return nil, err + } + + // Create bucket if it doesn't exist + err = db.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists([]byte(cfg.Storage.Bucket)) + return err + }) + if err != nil { + return nil, fmt.Errorf("create bucket: %w", err) + } + + return &Storage{ + db: db, + labBucket: []byte(cfg.Storage.Bucket), + }, nil +} + +func NewManager(provider provider.CloudProvider, cfg *config.Config) (Manager, error) { + sshManager := ssh.NewManager(cfg) + logLevel := logger.ParseLevel(cfg.LogLevel) + fmt.Println("logLevel", logLevel) + logger := logger.NewLogger(logLevel) + storage, err := NewLabStorage(cfg) + if err != nil { + return nil, err + } + return &ManagerSvc{ + storage: storage, + provider: provider, + sshManager: sshManager, + logger: logger, + }, nil +} + +func (m *ManagerSvc) Create(lab *types.Lab) error { + labAdminKeyName := strings.Join([]string{lab.ObjectMeta.Name, "admin"}, "-") + sshKeys := make([]*types.SSHKey, 2) // 2 keys: default admin key and lab admin key + sshKeys[0] = &types.SSHKey{ // default admin key is already on the cloud + ObjectMeta: types.ObjectMeta{ + Name: config.DefaultAdminKeyName, + }, + } + labAdminPublicKey, err := m.sshManager.CreateLocalKeyPair(labAdminKeyName) + if err != nil { + return err + } + labAdminCloudKey, err := m.provider.CreateSSHKey(options.SSHKeyCreateOpts{ + Name: labAdminKeyName, + PublicKey: labAdminPublicKey, + }) + if err != nil { + return err + } + sshKeys[1] = labAdminCloudKey + + ttl := lab.Spec.TTL + if ttl == "" { + ttl = config.DefaultTTL + } + // Create servers + specServers := lab.Spec.Servers + servers := make([]*types.Server, 0) + for _, serverSpec := range specServers { + s := &types.Server{ + TypeMeta: types.TypeMeta{ + Kind: "Server", + APIVersion: "v1", + }, + ObjectMeta: types.ObjectMeta{ + Name: strings.Join([]string{lab.ObjectMeta.Name, serverSpec.Name}, "-"), + Labels: lab.ObjectMeta.Labels, + }, + Spec: types.ServerSpec{ + Location: lab.Spec.Location, + Provider: lab.Spec.Provider, + ServerType: serverSpec.ServerType, + TTL: ttl, + Image: serverSpec.Image, + }, + } + result, err := m.provider.CreateServer(options.ServerCreateOpts{ + Name: s.ObjectMeta.Name, + Type: s.Spec.ServerType, + Image: s.Spec.Image, + Location: s.Spec.Location, + Provider: s.Spec.Provider, + SSHKeys: sshKeys, + }) + if err != nil { + return err + } + servers = append(servers, result) + } + + // Wait for servers to be ready + timeout := 30 * time.Minute + attempts := 20 + results, err := serverchecker.CheckServers(servers, m.logger, timeout, attempts) + if err != nil { + return err + } + for _, result := range results { + fmt.Printf("Server %s: Ready: %v\n", result.Server.ObjectMeta.Name, result.Ready) + if !result.Ready { + return fmt.Errorf("server %s not ready", result.Server.ObjectMeta.Name) + } + } + + // Create volumes + volumes := lab.Spec.Volumes + for _, volumeSpec := range volumes { + if !volumeSpec.Automount { // if not specified, default to false + volumeSpec.Automount = config.DefaultVolumeAutomount + } + if volumeSpec.Format == "" { // if not specified, default to xfs + volumeSpec.Format = config.DefaultVolumeFormat + } + v := &types.Volume{ + TypeMeta: types.TypeMeta{ + Kind: "Volume", + APIVersion: "v1", + }, + ObjectMeta: types.ObjectMeta{ + Name: strings.Join([]string{lab.ObjectMeta.Name, volumeSpec.Name}, "-"), + Labels: lab.ObjectMeta.Labels, + }, + Spec: types.VolumeSpec{ + Size: volumeSpec.Size, + ServerName: strings.Join([]string{lab.ObjectMeta.Name, volumeSpec.Server}, "-"), + Automount: volumeSpec.Automount, + Format: volumeSpec.Format, + }, + } + _, err := m.provider.CreateVolume(options.VolumeCreateOpts{ + Name: v.ObjectMeta.Name, + Size: v.Spec.Size, + ServerName: v.Spec.ServerName, + Automount: v.Spec.Automount, + Format: v.Spec.Format, + }) + if err != nil { + return err + } + } + return nil +} + +func (m *ManagerSvc) Get(labName string) (*types.Lab, error) { + lab, err := m.storage.Get(labName) + if err == nil { + return lab, nil + } + lab, err = m.syncLabFromCloud(labName) + if err != nil { + return nil, err + } + return lab, nil +} + +func (m *ManagerSvc) List() ([]*types.Lab, error) { + var labs []*types.Lab + + err := m.storage.db.View(func(tx *bbolt.Tx) error { + b := tx.Bucket(m.storage.labBucket) + if b == nil { + return fmt.Errorf("labs bucket not found in database") + } + + return b.ForEach(func(k, v []byte) error { + var lab types.Lab + if err := json.Unmarshal(v, &lab); err != nil { + return err + } + labs = append(labs, &lab) + return nil + }) + }) + + return labs, err +} + +func (m *ManagerSvc) SyncLabs() error { + labsMap := make(map[string]*types.Lab) + allServers, err := m.provider.AllServers() + if err != nil { + return err + } + // collect unique lab names + for _, server := range allServers { + if server.Labels["lab_name"] != "" { + labsMap[server.Labels["lab_name"]] = &types.Lab{} + } + } + for labName := range labsMap { + lab, err := m.getLabFromCloud(labName) + if err != nil { + return err + } + labsMap[labName] = lab + } + return m.storage.db.Update(func(tx *bbolt.Tx) error { + b := tx.Bucket(m.storage.labBucket) + + // Clear existing data + if err := b.ForEach(func(k, v []byte) error { + return b.Delete(k) + }); err != nil { + return err + } + + // Store new data + for labName, lab := range labsMap { + data, err := json.Marshal(lab) + if err != nil { + return err + } + if err := b.Put([]byte(labName), data); err != nil { + return err + } + } + return nil + }) +} + +func (m *ManagerSvc) Delete(labName string, force bool) error { + lab, err := m.Get(labName) + if err != nil { + return err + } + // delete volumes first + for _, volume := range lab.Status.Volumes { + m.logger.Debug("deleting volume", "volume", volume.ObjectMeta.Name) + status := m.provider.DeleteVolume(volume.ObjectMeta.Name, force) + if status.Error != nil { + m.logger.Error("failed to delete volume", "volume", volume.ObjectMeta.Name, "error", status.Error) + } + } + // delete servers + for _, server := range lab.Status.Servers { + // delete server's ssh keys + for _, sshKeyName := range server.Spec.SSHKeyNames { + m.logger.Debug("deleting ssh key", "key", sshKeyName) + status := m.provider.DeleteSSHKey(sshKeyName, force) + if status.Error != nil { + m.logger.Error("failed to delete ssh key", "key", sshKeyName, "error", status.Error) + } + } + m.logger.Debug("deleting server", "server", server.ObjectMeta.Name) + status := m.provider.DeleteServer(server.ObjectMeta.Name, force) + if status.Error != nil { + m.logger.Error("failed to delete server", "server", server.ObjectMeta.Name, "error", status.Error) + } + } + + // delete lab from storage + return m.storage.db.Update(func(tx *bbolt.Tx) error { + return tx.Bucket(m.storage.labBucket).Delete([]byte(labName)) + }) +} + +func (s *Storage) Get(labName string) (*types.Lab, error) { + var lab *types.Lab + + err := s.db.View(func(tx *bbolt.Tx) error { + b := tx.Bucket(s.labBucket) + data := b.Get([]byte(labName)) + if data == nil { + return fmt.Errorf("lab %s not found", labName) + } + + lab = &types.Lab{} + if err := json.Unmarshal(data, lab); err != nil { + return err + } + return nil + }) + + return lab, err +} + +func (s *Storage) Save(lab *types.Lab) error { + return s.db.Update(func(tx *bbolt.Tx) error { + b := tx.Bucket(s.labBucket) + data, err := json.Marshal(lab) + if err != nil { + return err + } + return b.Put([]byte(lab.Name), data) + }) +} + +func (m *ManagerSvc) syncLabFromCloud(labName string) (*types.Lab, error) { + lab, err := m.getLabFromCloud(labName) + if err != nil { + return nil, err + } + if err := m.storage.Save(lab); err != nil { + m.logger.Warn("failed to save lab to storage", "error", err) + } + return lab, nil +} + +func (m *ManagerSvc) getLabFromCloud(labName string) (*types.Lab, error) { + lab := &types.Lab{ + TypeMeta: types.TypeMeta{ + APIVersion: "v1", + Kind: "Lab", + }, + ObjectMeta: types.ObjectMeta{ + Name: labName, + }, + } + + servers, err := m.provider.ListServers(options.ServerListOpts{ + ListOpts: options.ListOpts{ + LabelSelector: "lab_name=" + labName, + }, + }) + if err != nil { + return nil, err + } + volumes, err := m.provider.ListVolumes(options.VolumeListOpts{ + ListOpts: options.ListOpts{ + LabelSelector: "lab_name=" + labName, + }, + }) + if err != nil { + return nil, err + } + lab.Status.Servers = append(lab.Status.Servers, servers...) + lab.Status.Volumes = append(lab.Status.Volumes, volumes...) + // Add labels from the first server + if len(servers) > 0 { + lab.ObjectMeta.Labels = servers[0].ObjectMeta.Labels + } + lab.Status.State = servers[0].Status.Status + lab.Status.Owner = servers[0].Status.Owner + lab.Status.Created = servers[0].Status.Created + lab.Status.DeleteAfter = servers[0].Status.DeleteAfter + lab.Spec.Location = servers[0].Spec.Location + lab.Spec.Provider = servers[0].Spec.Provider + return lab, nil +} diff --git a/internal/lab/mock/manager.go b/internal/lab/mock/manager.go new file mode 100644 index 0000000..ffcd816 --- /dev/null +++ b/internal/lab/mock/manager.go @@ -0,0 +1,36 @@ +package mock + +import ( + "github.com/pavelanni/storctl/internal/lab" + "github.com/pavelanni/storctl/internal/types" +) + +// Manager implements lab.ManagerSvc for testing +type Manager struct { + *lab.ManagerSvc // Embed the ManagerSvc + ListFunc func() ([]*types.Lab, error) + GetFunc func(name string) (*types.Lab, error) + GetFromCloudFunc func(name string) (*types.Lab, error) + CreateFunc func(lab *types.Lab) error + DeleteFunc func(name string, force bool) error +} + +func (m *Manager) List() ([]*types.Lab, error) { + return m.ListFunc() +} + +func (m *Manager) Get(name string) (*types.Lab, error) { + return m.GetFunc(name) +} + +func (m *Manager) Create(lab *types.Lab) error { + return m.CreateFunc(lab) +} + +func (m *Manager) Delete(name string, force bool) error { + return m.DeleteFunc(name, force) +} + +func (m *Manager) GetFromCloud(name string) (*types.Lab, error) { + return m.GetFromCloudFunc(name) +} diff --git a/internal/provider/hetzner/lab.go b/internal/provider/hetzner/lab.go deleted file mode 100644 index f3ad9a3..0000000 --- a/internal/provider/hetzner/lab.go +++ /dev/null @@ -1,239 +0,0 @@ -package hetzner - -import ( - "encoding/json" - - "github.com/pavelanni/storctl/internal/provider/options" - "github.com/pavelanni/storctl/internal/types" - "go.etcd.io/bbolt" -) - -func (p *HetznerProvider) CreateLab(name string, template string) error { - return nil -} - -func (p *HetznerProvider) GetLabFromCloud(labName string) (*types.Lab, error) { - lab := &types.Lab{ - TypeMeta: types.TypeMeta{ - APIVersion: "v1", - Kind: "Lab", - }, - ObjectMeta: types.ObjectMeta{ - Name: labName, - }, - } - - servers, err := p.ListServers(options.ServerListOpts{ - ListOpts: options.ListOpts{ - LabelSelector: "lab_name=" + labName, - }, - }) - if err != nil { - return nil, err - } - volumes, err := p.ListVolumes(options.VolumeListOpts{ - ListOpts: options.ListOpts{ - LabelSelector: "lab_name=" + labName, - }, - }) - if err != nil { - return nil, err - } - lab.Status.Servers = append(lab.Status.Servers, servers...) - lab.Status.Volumes = append(lab.Status.Volumes, volumes...) - // Add labels from the first server - if len(servers) > 0 { - lab.ObjectMeta.Labels = servers[0].ObjectMeta.Labels - } - lab.Status.Status = servers[0].Status.Status - lab.Status.Owner = servers[0].Status.Owner - lab.Status.Created = servers[0].Status.Created - lab.Status.DeleteAfter = servers[0].Status.DeleteAfter - lab.Spec.Location = servers[0].Spec.Location - lab.Spec.Provider = servers[0].Spec.Provider - return lab, nil -} - -func (p *HetznerProvider) GetLab(labName string) (*types.Lab, error) { - var lab *types.Lab - - err := p.db.View(func(tx *bbolt.Tx) error { - b := tx.Bucket(p.labBucket) - data := b.Get([]byte(labName)) - if data == nil { - // If not in cache, fetch from cloud - var err error - lab, err = p.GetLabFromCloud(labName) - if err == nil { - // store in cache - data, err := json.Marshal(lab) - if err != nil { - return err - } - return b.Put([]byte(labName), data) - } - return err - } - - lab = &types.Lab{} - if err := json.Unmarshal(data, lab); err != nil { - return err - } - return nil - }) - - return lab, err -} - -func (p *HetznerProvider) SyncLabs() error { - p.logger.Debug("syncing labs") - labsMap := make(map[string]*types.Lab) - allServers, err := p.AllServers() - if err != nil { - return err - } - // collect unique lab names - for _, server := range allServers { - if server.Labels["lab_name"] != "" { - labsMap[server.Labels["lab_name"]] = &types.Lab{} - } - } - for labName := range labsMap { - lab, err := p.GetLabFromCloud(labName) - if err != nil { - return err - } - labsMap[labName] = lab - } - return p.db.Update(func(tx *bbolt.Tx) error { - b := tx.Bucket(p.labBucket) - - // Clear existing data - if err := b.ForEach(func(k, v []byte) error { - return b.Delete(k) - }); err != nil { - return err - } - - // Store new data - for labName, lab := range labsMap { - data, err := json.Marshal(lab) - if err != nil { - return err - } - if err := b.Put([]byte(labName), data); err != nil { - return err - } - } - return nil - }) -} - -func (p *HetznerProvider) ListLabs(opts options.LabListOpts) ([]*types.Lab, error) { - var labs []*types.Lab - - err := p.db.View(func(tx *bbolt.Tx) error { - b := tx.Bucket(p.labBucket) - return b.ForEach(func(k, v []byte) error { - var lab types.Lab - if err := json.Unmarshal(v, &lab); err != nil { - return err - } - labs = append(labs, &lab) - return nil - }) - }) - - return labs, err -} - -func (p *HetznerProvider) DeleteLab(labName string, force bool) *types.LabDeleteStatus { - lab, err := p.GetLab(labName) - if err != nil { - p.logger.Error("failed to get lab details", - "lab", labName, - "error", err) - return &types.LabDeleteStatus{ - Error: err, - } - } - - // Get all SSH keys - sshKeys, err := p.AllSSHKeys() - if err != nil { - p.logger.Error("failed to get SSH keys", - "lab", labName, - "error", err) - return &types.LabDeleteStatus{ - Error: err, - } - } - - p.logger.Debug("deleting lab", - "lab", lab.Name, - "servers", len(lab.Status.Servers), - "volumes", len(lab.Status.Volumes)) - - // Delete volumes first - for _, volume := range lab.Status.Volumes { - p.logger.Debug("deleting volume", - "volume", volume.Name) - if status := p.DeleteVolume(volume.Name, force); status.Error != nil { - p.logger.Error("failed to delete volume", - "volume", volume.Name, - "error", status.Error) - return &types.LabDeleteStatus{ - Error: status.Error, - } - } - } - - // Delete servers - for _, server := range lab.Status.Servers { - p.logger.Debug("deleting server", - "server", server.Name) - if status := p.DeleteServer(server.Name, force); status.Error != nil { - p.logger.Error("failed to delete server", - "server", server.Name, - "error", status.Error) - return &types.LabDeleteStatus{ - Error: status.Error, - } - } - } - - // Delete SSH keys associated with this lab - for _, sshKey := range sshKeys { - if sshKey.Labels["lab_name"] == labName { - p.logger.Debug("deleting SSH key", - "key", sshKey.Name) - if status := p.DeleteSSHKey(sshKey.Name, force); status.Error != nil { - p.logger.Error("failed to delete SSH key", - "key", sshKey.Name, - "error", status.Error) - return &types.LabDeleteStatus{ - Error: status.Error, - } - } - } - } - - p.logger.Debug("lab deletion from the cloud completed successfully", - "lab", labName) - - // Delete from the database - if err := p.db.Update(func(tx *bbolt.Tx) error { - return tx.Bucket(p.labBucket).Delete([]byte(labName)) - }); err != nil { - p.logger.Error("failed to delete lab from the database", - "lab", labName, - "error", err) - return &types.LabDeleteStatus{ - Error: err, - } - } - - return &types.LabDeleteStatus{ - Deleted: true, - } -} diff --git a/internal/provider/hetzner/provider.go b/internal/provider/hetzner/provider.go index 6290026..461c670 100644 --- a/internal/provider/hetzner/provider.go +++ b/internal/provider/hetzner/provider.go @@ -7,15 +7,12 @@ import ( "github.com/hetznercloud/hcloud-go/v2/hcloud" "github.com/pavelanni/storctl/internal/config" "github.com/pavelanni/storctl/internal/logger" - "go.etcd.io/bbolt" ) type HetznerProvider struct { - Client *hcloud.Client - config *config.Config - logger *slog.Logger - db *bbolt.DB - labBucket []byte + Client *hcloud.Client + config *config.Config + logger *slog.Logger } func New(cfg *config.Config) (*HetznerProvider, error) { @@ -30,25 +27,9 @@ func New(cfg *config.Config) (*HetznerProvider, error) { client := hcloud.NewClient(hcloud.WithToken(token)) p := &HetznerProvider{ - Client: client, - config: cfg, - logger: providerLogger, - labBucket: []byte("labs"), - } - // Open the database - db, err := bbolt.Open("labs.db", 0600, nil) - if err != nil { - return nil, fmt.Errorf("failed to open db: %w", err) - } - p.db = db - - // Create the bucket - err = db.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(p.labBucket) - return err - }) - if err != nil { - return nil, fmt.Errorf("failed to create bucket: %w", err) + Client: client, + config: cfg, + logger: providerLogger, } return p, nil diff --git a/internal/provider/hetzner/server.go b/internal/provider/hetzner/server.go index 1513ade..c013d76 100644 --- a/internal/provider/hetzner/server.go +++ b/internal/provider/hetzner/server.go @@ -6,6 +6,7 @@ import ( "time" "github.com/hetznercloud/hcloud-go/v2/hcloud" + "github.com/pavelanni/storctl/internal/config" "github.com/pavelanni/storctl/internal/provider/options" "github.com/pavelanni/storctl/internal/types" "github.com/pavelanni/storctl/internal/util/timeutil" @@ -147,6 +148,25 @@ func (p *HetznerProvider) DeleteServer(serverName string, force bool) *types.Ser } } +func (p *HetznerProvider) ServerToCreateOpts(server *types.Server) (options.ServerCreateOpts, error) { + sshKeys, err := p.KeyNamesToSSHKeys(server.Spec.SSHKeyNames, options.SSHKeyCreateOpts{ + Labels: server.ObjectMeta.Labels, + }) + if err != nil { + return options.ServerCreateOpts{}, err + } + cloudInitUserData := fmt.Sprintf(config.DefaultCloudInitUserData, sshKeys[0].Spec.PublicKey) + return options.ServerCreateOpts{ + Name: server.ObjectMeta.Name, + Type: server.Spec.ServerType, + Image: server.Spec.Image, + Location: server.Spec.Location, + Provider: "hetzner", + SSHKeys: sshKeys, + UserData: cloudInitUserData, + }, nil +} + // mapServer converts a Hetzner-specific server to our generic Server type func (p *HetznerProvider) mapServer(s *hcloud.Server) *types.Server { if s == nil { diff --git a/internal/provider/hetzner/sshkey.go b/internal/provider/hetzner/sshkey.go index 4e4aca0..5f4f650 100644 --- a/internal/provider/hetzner/sshkey.go +++ b/internal/provider/hetzner/sshkey.go @@ -6,7 +6,9 @@ import ( "time" "github.com/hetznercloud/hcloud-go/v2/hcloud" + "github.com/pavelanni/storctl/internal/config" "github.com/pavelanni/storctl/internal/provider/options" + "github.com/pavelanni/storctl/internal/ssh" "github.com/pavelanni/storctl/internal/types" "github.com/pavelanni/storctl/internal/util/timeutil" ) @@ -38,6 +40,9 @@ func (p *HetznerProvider) GetSSHKey(name string) (*types.SSHKey, error) { "key", name) return nil, fmt.Errorf("SSH key not found") } + p.logger.Debug("SSH key found", + "key", name, + "public_key", sshKey.PublicKey) return mapSSHKey(sshKey), nil } @@ -68,14 +73,14 @@ func (p *HetznerProvider) DeleteSSHKey(name string, force bool) *types.SSHKeyDel } } - keyExists, err := p.KeyExists(name) + keyExists, err := p.CloudKeyExists(name) if err != nil { return &types.SSHKeyDeleteStatus{ Error: err, } } if !keyExists { - p.logger.Debug("SSH key not found, skipping", + p.logger.Debug("SSH key not found on the cloud, skipping", "key", name) return &types.SSHKeyDeleteStatus{ Deleted: true, @@ -93,7 +98,7 @@ func (p *HetznerProvider) DeleteSSHKey(name string, force bool) *types.SSHKeyDel if !force { if deleteAfterStr, ok := sshKey.Labels["delete_after"]; ok { - deleteAfter, err := time.Parse(time.RFC3339, deleteAfterStr) + deleteAfter := timeutil.ParseDeleteAfter(deleteAfterStr) if err == nil && time.Now().UTC().Before(deleteAfter) { p.logger.Warn("key not ready for deletion", "key", name, @@ -105,12 +110,12 @@ func (p *HetznerProvider) DeleteSSHKey(name string, force bool) *types.SSHKeyDel } } - p.logger.Debug("deleting SSH key", + p.logger.Debug("deleting cloud SSH key", "key", name) _, err = p.Client.SSHKey.Delete(context.Background(), sshKey) if err != nil { - p.logger.Error("failed to delete SSH key", + p.logger.Error("failed to delete cloud SSH key", "key", name) } return &types.SSHKeyDeleteStatus{ @@ -118,12 +123,56 @@ func (p *HetznerProvider) DeleteSSHKey(name string, force bool) *types.SSHKeyDel } } -func (p *HetznerProvider) KeyExists(name string) (bool, error) { - sshKey, _, err := p.Client.SSHKey.GetByName(context.Background(), name) +func (p *HetznerProvider) CloudKeyExists(name string) (bool, error) { + // check if the cloud key exists + cloudKey, _, err := p.Client.SSHKey.GetByName(context.Background(), name) if err != nil { return false, fmt.Errorf("failed to check SSH key existence: %w", err) } - return sshKey != nil, nil + + return cloudKey != nil, nil +} + +// KeyNamesToSSHKeys converts a list of SSH key names to a list of SSH keys +// It will upload local SSH keys to the cloud if they don't exist +// It adds the default admin key to the list +func (p *HetznerProvider) KeyNamesToSSHKeys(keyNames []string, opts options.SSHKeyCreateOpts) ([]*types.SSHKey, error) { + sshManager := ssh.NewManager(p.config) + sshKeys := make([]*types.SSHKey, 0) + adminKey, err := p.GetSSHKey(config.DefaultAdminKeyName) + if err != nil { + return nil, fmt.Errorf("failed to get admin key: %w", err) + } + sshKeys = append(sshKeys, adminKey) + + for _, keyName := range keyNames { + cloudKeyExists, err := p.CloudKeyExists(keyName) + if err != nil { + return nil, err + } + if !cloudKeyExists { + // check if the key exists locally + localKeyExists, err := sshManager.LocalKeyExists(keyName) + if err != nil { + return nil, err + } + if !localKeyExists { + fmt.Printf("SSH key %s not found locally, skipping it\n", keyName) + continue + } + pubKey, err := sshManager.ReadLocalPublicKey(keyName) + if err != nil { + return nil, fmt.Errorf("failed to read local public key: %w", err) + } + opts.PublicKey = pubKey + newKey, err := p.CreateSSHKey(opts) + if err != nil { + return nil, fmt.Errorf("failed to create SSH key: %w", err) + } + sshKeys = append(sshKeys, newKey) + } + } + return sshKeys, nil } func mapSSHKey(sk *hcloud.SSHKey) *types.SSHKey { diff --git a/internal/provider/mock/mock_provider.go b/internal/provider/mock/mock_provider.go new file mode 100644 index 0000000..9910059 --- /dev/null +++ b/internal/provider/mock/mock_provider.go @@ -0,0 +1,200 @@ +package mock + +import ( + "github.com/pavelanni/storctl/internal/provider" + "github.com/pavelanni/storctl/internal/provider/options" + "github.com/pavelanni/storctl/internal/types" +) + +// MockProvider implements the CloudProvider interface for testing +type MockProvider struct { + // Function fields to customize behavior + CreateServerFunc func(opts options.ServerCreateOpts) (*types.Server, error) + GetServerFunc func(name string) (*types.Server, error) + ListServersFunc func(opts options.ServerListOpts) ([]*types.Server, error) + AllServersFunc func() ([]*types.Server, error) + DeleteServerFunc func(name string, force bool) *types.ServerDeleteStatus + ServerToCreateOptsFunc func(server *types.Server) (options.ServerCreateOpts, error) + CreateVolumeFunc func(opts options.VolumeCreateOpts) (*types.Volume, error) + GetVolumeFunc func(name string) (*types.Volume, error) + ListVolumesFunc func(opts options.VolumeListOpts) ([]*types.Volume, error) + AllVolumesFunc func() ([]*types.Volume, error) + DeleteVolumeFunc func(name string, force bool) *types.VolumeDeleteStatus + CreateLabOnCloudFunc func(lab *types.Lab) error + GetLabFromCloudFunc func(name string) (*types.Lab, error) + ListLabsFunc func(opts options.LabListOpts) ([]*types.Lab, error) + DeleteLabFromCloudFunc func(name string, force bool) *types.LabDeleteStatus + SyncLabsFunc func() error + AllSSHKeysFunc func() ([]*types.SSHKey, error) + CreateSSHKeyFunc func(opts options.SSHKeyCreateOpts) (*types.SSHKey, error) + DeleteSSHKeyFunc func(name string, force bool) *types.SSHKeyDeleteStatus + GetSSHKeyFunc func(name string) (*types.SSHKey, error) + CloudKeyExistsFunc func(name string) (bool, error) + ListSSHKeysFunc func(opts options.SSHKeyListOpts) ([]*types.SSHKey, error) + KeyNamesToSSHKeysFunc func(keyNames []string, opts options.SSHKeyCreateOpts) ([]*types.SSHKey, error) +} + +// Ensure MockProvider implements CloudProvider interface +var _ provider.CloudProvider = &MockProvider{} + +// Implementation of interface methods +func (m *MockProvider) CreateServer(opts options.ServerCreateOpts) (*types.Server, error) { + if m.CreateServerFunc != nil { + return m.CreateServerFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) GetServer(name string) (*types.Server, error) { + if m.GetServerFunc != nil { + return m.GetServerFunc(name) + } + return nil, nil +} + +func (m *MockProvider) ListServers(opts options.ServerListOpts) ([]*types.Server, error) { + if m.ListServersFunc != nil { + return m.ListServersFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) AllServers() ([]*types.Server, error) { + if m.AllServersFunc != nil { + return m.AllServersFunc() + } + return nil, nil +} + +func (m *MockProvider) DeleteServer(name string, force bool) *types.ServerDeleteStatus { + if m.DeleteServerFunc != nil { + return m.DeleteServerFunc(name, force) + } + return &types.ServerDeleteStatus{} +} + +func (m *MockProvider) ServerToCreateOpts(server *types.Server) (options.ServerCreateOpts, error) { + if m.ServerToCreateOptsFunc != nil { + return m.ServerToCreateOptsFunc(server) + } + return options.ServerCreateOpts{}, nil +} + +func (m *MockProvider) CreateVolume(opts options.VolumeCreateOpts) (*types.Volume, error) { + if m.CreateVolumeFunc != nil { + return m.CreateVolumeFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) GetVolume(name string) (*types.Volume, error) { + if m.GetVolumeFunc != nil { + return m.GetVolumeFunc(name) + } + return nil, nil +} + +func (m *MockProvider) ListVolumes(opts options.VolumeListOpts) ([]*types.Volume, error) { + if m.ListVolumesFunc != nil { + return m.ListVolumesFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) AllVolumes() ([]*types.Volume, error) { + if m.AllVolumesFunc != nil { + return m.AllVolumesFunc() + } + return nil, nil +} + +func (m *MockProvider) DeleteVolume(name string, force bool) *types.VolumeDeleteStatus { + if m.DeleteVolumeFunc != nil { + return m.DeleteVolumeFunc(name, force) + } + return &types.VolumeDeleteStatus{} +} + +func (m *MockProvider) CreateLabOnCloud(lab *types.Lab) error { + if m.CreateLabOnCloudFunc != nil { + return m.CreateLabOnCloudFunc(lab) + } + return nil +} + +func (m *MockProvider) GetLabFromCloud(name string) (*types.Lab, error) { + if m.GetLabFromCloudFunc != nil { + return m.GetLabFromCloudFunc(name) + } + return nil, nil +} + +func (m *MockProvider) ListLabs(opts options.LabListOpts) ([]*types.Lab, error) { + if m.ListLabsFunc != nil { + return m.ListLabsFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) DeleteLabFromCloud(name string, force bool) *types.LabDeleteStatus { + if m.DeleteLabFromCloudFunc != nil { + return m.DeleteLabFromCloudFunc(name, force) + } + return &types.LabDeleteStatus{} +} + +func (m *MockProvider) SyncLabs() error { + if m.SyncLabsFunc != nil { + return m.SyncLabsFunc() + } + return nil +} + +func (m *MockProvider) AllSSHKeys() ([]*types.SSHKey, error) { + if m.AllSSHKeysFunc != nil { + return m.AllSSHKeysFunc() + } + return nil, nil +} + +func (m *MockProvider) CreateSSHKey(opts options.SSHKeyCreateOpts) (*types.SSHKey, error) { + if m.CreateSSHKeyFunc != nil { + return m.CreateSSHKeyFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) DeleteSSHKey(name string, force bool) *types.SSHKeyDeleteStatus { + if m.DeleteSSHKeyFunc != nil { + return m.DeleteSSHKeyFunc(name, force) + } + return &types.SSHKeyDeleteStatus{} +} + +func (m *MockProvider) GetSSHKey(name string) (*types.SSHKey, error) { + if m.GetSSHKeyFunc != nil { + return m.GetSSHKeyFunc(name) + } + return nil, nil +} + +func (m *MockProvider) CloudKeyExists(name string) (bool, error) { + if m.CloudKeyExistsFunc != nil { + return m.CloudKeyExistsFunc(name) + } + return false, nil +} + +func (m *MockProvider) ListSSHKeys(opts options.SSHKeyListOpts) ([]*types.SSHKey, error) { + if m.ListSSHKeysFunc != nil { + return m.ListSSHKeysFunc(opts) + } + return nil, nil +} + +func (m *MockProvider) KeyNamesToSSHKeys(keyNames []string, opts options.SSHKeyCreateOpts) ([]*types.SSHKey, error) { + if m.KeyNamesToSSHKeysFunc != nil { + return m.KeyNamesToSSHKeysFunc(keyNames, opts) + } + return nil, nil +} diff --git a/internal/provider/mock/mock_provider_test.go b/internal/provider/mock/mock_provider_test.go new file mode 100644 index 0000000..d6d6347 --- /dev/null +++ b/internal/provider/mock/mock_provider_test.go @@ -0,0 +1,112 @@ +package mock + +import ( + "errors" + "testing" + + "github.com/pavelanni/storctl/internal/provider/options" + "github.com/pavelanni/storctl/internal/types" +) + +func TestMockProvider(t *testing.T) { + t.Run("CreateServer", func(t *testing.T) { + mock := &MockProvider{ + CreateServerFunc: func(opts options.ServerCreateOpts) (*types.Server, error) { + if opts.Name == "test-server" { + return &types.Server{ + ObjectMeta: types.ObjectMeta{ + Name: opts.Name, + }, + }, nil + } + return nil, errors.New("server creation failed") + }, + } + + // Test successful case + server, err := mock.CreateServer(options.ServerCreateOpts{Name: "test-server"}) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if server.Name != "test-server" { + t.Errorf("Expected server name 'test-server', got %s", server.Name) + } + + // Test error case + _, err = mock.CreateServer(options.ServerCreateOpts{Name: "invalid-server"}) + if err == nil { + t.Error("Expected error, got nil") + } + }) + + t.Run("DeleteServer", func(t *testing.T) { + mock := &MockProvider{ + DeleteServerFunc: func(name string, force bool) *types.ServerDeleteStatus { + if name == "test-server" { + return &types.ServerDeleteStatus{ + Deleted: true, + } + } + return &types.ServerDeleteStatus{ + Error: errors.New("server not found"), + } + }, + } + + // Test successful deletion + status := mock.DeleteServer("test-server", false) + if !status.Deleted { + t.Error("Expected server to be deleted") + } + if status.Error != nil { + t.Errorf("Expected no error, got %v", status.Error) + } + + // Test failed deletion + status = mock.DeleteServer("nonexistent-server", false) + if status.Deleted { + t.Error("Expected server not to be deleted") + } + if status.Error == nil { + t.Error("Expected error, got nil") + } + }) + + t.Run("GetVolume", func(t *testing.T) { + expectedVolume := &types.Volume{ + ObjectMeta: types.ObjectMeta{ + Name: "test-volume", + }, + Spec: types.VolumeSpec{ + Size: 100, + }, + } + + mock := &MockProvider{ + GetVolumeFunc: func(name string) (*types.Volume, error) { + if name == "test-volume" { + return expectedVolume, nil + } + return nil, errors.New("volume not found") + }, + } + + // Test successful case + volume, err := mock.GetVolume("test-volume") + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if volume.Name != expectedVolume.Name { + t.Errorf("Expected volume name %s, got %s", expectedVolume.Name, volume.Name) + } + if volume.Spec.Size != expectedVolume.Spec.Size { + t.Errorf("Expected volume size %d, got %d", expectedVolume.Spec.Size, volume.Spec.Size) + } + + // Test error case + _, err = mock.GetVolume("nonexistent-volume") + if err == nil { + t.Error("Expected error, got nil") + } + }) +} diff --git a/internal/provider/types.go b/internal/provider/types.go index cd67292..90d5272 100644 --- a/internal/provider/types.go +++ b/internal/provider/types.go @@ -12,7 +12,7 @@ type CloudProvider interface { ListServers(opts options.ServerListOpts) ([]*types.Server, error) AllServers() ([]*types.Server, error) DeleteServer(name string, force bool) *types.ServerDeleteStatus - + ServerToCreateOpts(server *types.Server) (options.ServerCreateOpts, error) // Volume operations CreateVolume(opts options.VolumeCreateOpts) (*types.Volume, error) GetVolume(name string) (*types.Volume, error) @@ -20,19 +20,12 @@ type CloudProvider interface { AllVolumes() ([]*types.Volume, error) DeleteVolume(name string, force bool) *types.VolumeDeleteStatus - // Lab operations - CreateLab(name string, template string) error - GetLab(name string) (*types.Lab, error) - GetLabFromCloud(name string) (*types.Lab, error) - ListLabs(opts options.LabListOpts) ([]*types.Lab, error) - DeleteLab(name string, force bool) *types.LabDeleteStatus - SyncLabs() error - // SSH Key operations CreateSSHKey(opts options.SSHKeyCreateOpts) (*types.SSHKey, error) GetSSHKey(name string) (*types.SSHKey, error) ListSSHKeys(opts options.SSHKeyListOpts) ([]*types.SSHKey, error) AllSSHKeys() ([]*types.SSHKey, error) DeleteSSHKey(name string, force bool) *types.SSHKeyDeleteStatus - KeyExists(name string) (bool, error) + CloudKeyExists(name string) (bool, error) + KeyNamesToSSHKeys(keyNames []string, opts options.SSHKeyCreateOpts) ([]*types.SSHKey, error) } diff --git a/internal/ssh/keys.go b/internal/ssh/keys.go new file mode 100644 index 0000000..bcf12b0 --- /dev/null +++ b/internal/ssh/keys.go @@ -0,0 +1,159 @@ +// Package ssh provides functions to manage local SSH keys. +package ssh + +import ( + "crypto" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "encoding/pem" + "fmt" + "log/slog" + "os" + "path/filepath" + + "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/logger" + "golang.org/x/crypto/ssh" +) + +type Manager struct { + keysDir string + logger *slog.Logger +} + +func NewManager(cfg *config.Config) *Manager { + logLevel := logger.ParseLevel(cfg.LogLevel) + return &Manager{ + keysDir: filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, config.DefaultKeysDir), + logger: logger.NewLogger(logLevel), + } +} + +// CreateLocalKeyPair creates a local SSH key pair +// Returns the public key string in OpenSSH format. +func (m *Manager) CreateLocalKeyPair(name string) (publicKey string, err error) { + // Generate the key pair + pubKey, privKey, err := generateED25519KeyPair(name) + if err != nil { + return "", fmt.Errorf("failed to generate key pair: %w", err) + } + + // Save the keys locally + if err := os.MkdirAll(m.keysDir, 0700); err != nil { + return "", fmt.Errorf("failed to create keys directory: %w", err) + } + + // Save private key + privKeyPath := filepath.Join(m.keysDir, name) + if err := os.WriteFile(privKeyPath, privKey, 0600); err != nil { + return "", fmt.Errorf("failed to save private key: %w", err) + } + + m.logger.Debug("created local key pair", + "name", name, + "path", privKeyPath) + + return string(pubKey), nil +} + +// ReadLocalPublicKey reads a local public SSH key. +func (m *Manager) ReadLocalPublicKey(name string) (string, error) { + pubKeyPath := filepath.Join(m.keysDir, name+".pub") + pubKey, err := os.ReadFile(pubKeyPath) + if err != nil { + return "", fmt.Errorf("failed to read local public key: %w", err) + } + return string(pubKey), nil +} + +// DeleteLocalKeyPair deletes a local SSH key pair. +func (m *Manager) DeleteLocalKeyPair(name string) error { + privKeyPath := filepath.Join(m.keysDir, name) + pubKeyPath := filepath.Join(m.keysDir, name+".pub") + + m.logger.Debug("deleting local key pair", + "name", name, + "private_key", privKeyPath, + "public_key", pubKeyPath) + + if err := m.deleteKeyFile(privKeyPath); err != nil { + return err + } + if err := m.deleteKeyFile(pubKeyPath); err != nil { + return err + } + return nil +} + +// LocalKeyExists checks if a local SSH key pair exists. +func (m *Manager) LocalKeyExists(name string) (bool, error) { + privKeyPath := filepath.Join(m.keysDir, name) + pubKeyPath := filepath.Join(m.keysDir, name+".pub") + _, err := os.Stat(privKeyPath) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("failed to check local private SSH key existence: %w", err) + } + _, err = os.Stat(pubKeyPath) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("failed to check local public SSH key existence: %w", err) + } + return true, nil +} + +func (m *Manager) deleteKeyFile(path string) error { + if _, err := os.Stat(path); err == nil { + if err := os.Remove(path); err != nil { + m.logger.Error("failed to delete key file", + "path", path, + "error", err) + return fmt.Errorf("failed to delete %s: %w", path, err) + } + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to check %s existence: %w", path, err) + } + return nil +} + +// generateED25519KeyPair generates a new ED25519 keypair. +// Returns public key in OpenSSH format and private key in PEM format as byte slices. +func generateED25519KeyPair(comment string) (publicKey, privateKey []byte, err error) { + // Generate the keypair + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate ED25519 keypair: %w", err) + } + + // Convert to SSH private key format and encode as PEM + pemBlock, err := ssh.MarshalPrivateKey(crypto.PrivateKey(priv), comment) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + // Encode private key in PEM format + privateKey = pem.EncodeToMemory(pemBlock) + if privateKey == nil { + return nil, nil, fmt.Errorf("failed to encode private key") + } + + // Generate the public key + sshPub, err := ssh.NewPublicKey(pub) + if err != nil { + return nil, nil, fmt.Errorf("failed to create public key: %w", err) + } + + // Format public key in OpenSSH format: "ssh-ed25519 comment" + pubKey := fmt.Sprintf("%s %s", sshPub.Type(), + base64.StdEncoding.EncodeToString(sshPub.Marshal())) + if comment != "" { + pubKey = fmt.Sprintf("%s %s", pubKey, comment) + } + + return []byte(pubKey), privateKey, nil +} diff --git a/internal/types/errors.go b/internal/types/errors.go new file mode 100644 index 0000000..17a41db --- /dev/null +++ b/internal/types/errors.go @@ -0,0 +1,19 @@ +package types + +// ProviderError represents an error from a cloud provider +type ProviderError struct { + Code string + Message string +} + +func (e *ProviderError) Error() string { + return e.Message +} + +// NewError creates a new ProviderError +func NewError(code, message string) error { + return &ProviderError{ + Code: code, + Message: message, + } +} diff --git a/internal/types/types.go b/internal/types/types.go index 0949d0f..578e08f 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -98,7 +98,7 @@ type LabSpec struct { } type LabStatus struct { - Status string `json:"status"` + State string `json:"state"` Owner string `json:"owner"` Servers []*Server `json:"servers"` Volumes []*Volume `json:"volumes"` @@ -152,6 +152,14 @@ type SSHKeyDeleteStatus struct { Error error `json:"error"` } +type SSHKeyExistsStatus struct { + LocalExists bool `json:"localExists"` + CloudExists bool `json:"cloudExists"` + CloudExpired bool `json:"cloudExpired"` + DeleteAfter time.Time `json:"deleteAfter"` + Error error `json:"error"` +} + // Resource represents the common fields for all resources type Resource struct { TypeMeta `json:",inline"` diff --git a/internal/util/serverchecker/serverchecker.go b/internal/util/serverchecker/serverchecker.go index 3ad4ac5..fa63b93 100644 --- a/internal/util/serverchecker/serverchecker.go +++ b/internal/util/serverchecker/serverchecker.go @@ -12,7 +12,6 @@ import ( "log/slog" "github.com/pavelanni/storctl/internal/config" - "github.com/pavelanni/storctl/internal/logger" "github.com/pavelanni/storctl/internal/types" ) @@ -31,9 +30,7 @@ type ServerResult struct { Error error } -func NewServerChecker(host string, user string, keyPath string, logLevel string, timeout time.Duration, attempts int) (*ServerChecker, error) { - level := logger.ParseLevel(logLevel) - serverCheckerLogger := logger.NewLogger(level) +func NewServerChecker(host string, user string, keyPath string, logger *slog.Logger, timeout time.Duration, attempts int) (*ServerChecker, error) { // check if key exists if _, err := os.Stat(keyPath); os.IsNotExist(err) { return nil, fmt.Errorf("key file does not exist: %s", keyPath) @@ -51,11 +48,11 @@ func NewServerChecker(host string, user string, keyPath string, logLevel string, host: host, attempts: attempts, timeout: timeout, - logger: serverCheckerLogger, + logger: logger, }, nil } -func CheckServers(servers []*types.Server, logLevel string, timeout time.Duration, attempts int) ([]ServerResult, error) { +func CheckServers(servers []*types.Server, logger *slog.Logger, timeout time.Duration, attempts int) ([]ServerResult, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() @@ -66,7 +63,7 @@ func CheckServers(servers []*types.Server, logLevel string, timeout time.Duratio serverIP := server.Status.PublicNet.IPv4.IP serverPrivateKeyPath := filepath.Join(os.Getenv("HOME"), config.DefaultConfigDir, - config.KeysDir, + config.DefaultKeysDir, strings.Join([]string{server.ObjectMeta.Labels["lab_name"], "admin"}, "-")) if serverIP == "" { results[i] = ServerResult{Server: server, Error: fmt.Errorf("server IP is empty")} @@ -75,7 +72,7 @@ func CheckServers(servers []*types.Server, logLevel string, timeout time.Duratio go func(i int, server *types.Server) { defer wg.Done() - sc, err := NewServerChecker(serverIP+":22", config.DefaultAdminUser, serverPrivateKeyPath, logLevel, timeout, attempts) + sc, err := NewServerChecker(serverIP+":22", config.DefaultAdminUser, serverPrivateKeyPath, logger, timeout, attempts) if err != nil { results[i] = ServerResult{Server: server, Error: err} return diff --git a/internal/util/serverchecker/serverchecker_test.go b/internal/util/serverchecker/serverchecker_test.go index 5e8da49..0f0882c 100644 --- a/internal/util/serverchecker/serverchecker_test.go +++ b/internal/util/serverchecker/serverchecker_test.go @@ -9,19 +9,21 @@ import ( "time" "github.com/pavelanni/storctl/internal/config" + "github.com/pavelanni/storctl/internal/logger" "github.com/pavelanni/storctl/internal/types" ) func TestNewServerChecker(t *testing.T) { + logger := logger.NewLogger(slog.LevelDebug) // Test creation with invalid key path - _, err := NewServerChecker("localhost:22", config.DefaultAdminUser, "/nonexistent/key", "debug", 1*time.Minute, 1) + _, err := NewServerChecker("localhost:22", config.DefaultAdminUser, "/nonexistent/key", logger, 1*time.Minute, 1) if err == nil { t.Error("Expected error for nonexistent key, got nil") } // Test creation with valid parameters (you'll need to provide a real test key) // TODO: Add path to a test SSH key - checker, err := NewServerChecker("localhost:22", "testuser", "testdata/test_key", "debug", 1*time.Minute, 1) + checker, err := NewServerChecker("localhost:22", "testuser", "testdata/test_key", logger, 1*time.Minute, 1) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -33,6 +35,7 @@ func TestNewServerChecker(t *testing.T) { func TestCheckServers(t *testing.T) { t.Parallel() + logger := logger.NewLogger(slog.LevelDebug) // Create test servers with mock IPs servers := []*types.Server{ { @@ -68,7 +71,7 @@ func TestCheckServers(t *testing.T) { } // Use shorter timeout for tests - results, err := CheckServers(servers, "debug", 100*time.Millisecond, 2) + results, err := CheckServers(servers, logger, 100*time.Millisecond, 2) t.Logf("results: %+v, err: %+v", results, err) // Expect error because these are not real servers @@ -94,13 +97,13 @@ func TestCheckServers(t *testing.T) { // Add a new test specifically for checkServerReady func TestServerChecker_checkServerReady(t *testing.T) { t.Parallel() - + logger := logger.NewLogger(slog.LevelDebug) // Create a ServerChecker with shorter intervals for testing sc, err := NewServerChecker( "192.0.2.1:22", // non-routable IP "testuser", "testdata/test_key", - "debug", + logger, 5*time.Second, // total timeout 3, // number of attempts )