diff --git a/pkg/mockssh/server.go b/pkg/mockssh/server.go index e140ec9..050e92b 100644 --- a/pkg/mockssh/server.go +++ b/pkg/mockssh/server.go @@ -12,7 +12,6 @@ import ( "io" "net" "net/http" - "os" "os/exec" "sync" "testing" @@ -31,9 +30,9 @@ type Server struct { CertAuthorityKeys []ssh.PublicKey CertChecker ssh.CertChecker - // RemoteEnv, RemoteDir and CommandHandler are optional configuration. - RemoteEnv []string - RemoteDir string + // An optional CommandHandler, which responds to commands sent over SSH. + // NewServer will give this a default using ExecHandler, which can also + // be reused from custom handlers. CommandHandler CommandHandler // listener and port are set after Start. @@ -47,7 +46,7 @@ type CommandIO struct { StdErr io.Writer } -type CommandHandler func(conn ssh.ConnMetadata, command string, io CommandIO) int +type CommandHandler func(conn ssh.ConnMetadata, command string, commandIO CommandIO) int // NewServer creates and starts a local SSH server for a test. // It must be stopped with the Server.Stop method. @@ -65,9 +64,8 @@ func NewServer(t *testing.T, authorityEndpoint string) (*Server, error) { } s := &Server{t: t, hostKey: hk} - s.CommandHandler = s.defaultCommandHandler + s.CommandHandler = ExecHandler("", nil) s.CertChecker = s.defaultCertChecker() - s.RemoteDir = t.TempDir() s.CertAuthorityKeys = keys if err := s.start(); err != nil { @@ -89,6 +87,10 @@ func (s *Server) HostKeyConfig() string { ) } +func (s *Server) HostKey() ssh.PublicKey { + return s.hostKey.PublicKey() +} + func (s *Server) start() error { t := s.t @@ -148,22 +150,25 @@ func (s *Server) Stop() error { return nil } -func (s *Server) defaultCommandHandler(_ ssh.ConnMetadata, command string, commandIO CommandIO) int { - c := exec.Command("bash", "-c", command) - c.Stdout = commandIO.StdOut - c.Stderr = commandIO.StdErr - c.Stdin = commandIO.StdIn - c.Dir = s.RemoteDir - c.Env = append(os.Environ(), s.RemoteEnv...) - if err := c.Run(); err != nil { - exitErr := &exec.ExitError{} - if errors.As(err, &exitErr) { - return exitErr.ExitCode() +// ExecHandler returns a CommandHandler to execute a command in the given environment. +func ExecHandler(workingDir string, env []string) CommandHandler { + return func(_ ssh.ConnMetadata, command string, commandIO CommandIO) int { + c := exec.Command("bash", "-c", command) + c.Stdout = commandIO.StdOut + c.Stderr = commandIO.StdErr + c.Stdin = commandIO.StdIn + c.Dir = workingDir + c.Env = env + if err := c.Run(); err != nil { + exitErr := &exec.ExitError{} + if errors.As(err, &exitErr) { + return exitErr.ExitCode() + } + _, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err) + return 1 } - _, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err) - return 1 + return 0 } - return 0 } func (s *Server) defaultCertChecker() ssh.CertChecker { @@ -253,9 +258,9 @@ func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChann for { select { case s := <-exitWithStatus: - _, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status int }{s})) + _, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{uint32(s)})) //nolint: gosec if err != nil { - t.Errorf("Failed to send exit status: %v", err) + t.Fatalf("Failed to send exit status: %v", err) } goto closeChannel case <-timer.C: diff --git a/pkg/mockssh/server_test.go b/pkg/mockssh/server_test.go new file mode 100644 index 0000000..2b6d068 --- /dev/null +++ b/pkg/mockssh/server_test.go @@ -0,0 +1,105 @@ +package mockssh_test + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "fmt" + "net" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/platformsh/cli/pkg/mockapi" + "github.com/platformsh/cli/pkg/mockssh" +) + +func TestServer(t *testing.T) { + authServer := mockapi.NewAuthServer(t) + defer authServer.Close() + + sshServer, err := mockssh.NewServer(t, authServer.URL+"/ssh/authority") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _ = sshServer.Stop() + }) + + tempDir := t.TempDir() + sshServer.CommandHandler = mockssh.ExecHandler(tempDir, []string{}) + + cert := getTestSSHAuth(t, authServer.URL) + + // Create the SSH client configuration + address := fmt.Sprintf("127.0.0.1:%d", sshServer.Port()) + config := &ssh.ClientConfig{ + User: "test", + Auth: []ssh.AuthMethod{ssh.PublicKeys(cert)}, + HostKeyCallback: func(_ string, remote net.Addr, key ssh.PublicKey) error { + if remote.String() != address { + return fmt.Errorf("unexpected address: %s", remote.String()) + } + if bytes.Equal(sshServer.HostKey().Marshal(), key.Marshal()) { + return nil + } + return fmt.Errorf("host key mismatch") + }, + } + + client, err := ssh.Dial("tcp", address, config) + require.NoError(t, err) + defer client.Close() + + session, err := client.NewSession() + require.NoError(t, err) + defer session.Close() + + stdOutBuffer := &bytes.Buffer{} + session.Stdout = stdOutBuffer + + require.NoError(t, session.Run("pwd")) + assert.Equal(t, tempDir, strings.TrimRight(stdOutBuffer.String(), "\n")) + + session2, err := client.NewSession() + require.NoError(t, err) + defer session2.Close() + err = session2.Run("false") + assert.Error(t, err) + var exitErr *ssh.ExitError + assert.ErrorAs(t, err, &exitErr) + assert.Equal(t, 1, exitErr.ExitStatus()) +} + +func getTestSSHAuth(t *testing.T, authServerURL string) ssh.Signer { + t.Helper() + + // Generate a keypair + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + s, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + + b, err := json.Marshal(struct{ Key string }{string(ssh.MarshalAuthorizedKey(s.PublicKey()))}) + require.NoError(t, err) + resp, err := http.DefaultClient.Post(authServerURL+"/ssh", "application/json", bytes.NewReader(b)) + require.NoError(t, err) + defer resp.Body.Close() + + var rs struct{ Certificate string } + require.NoError(t, json.NewDecoder(resp.Body).Decode(&rs)) + + parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(rs.Certificate)) //nolint: dogsled + require.NoError(t, err) + + cert, _ := parsed.(*ssh.Certificate) + certSigner, err := ssh.NewCertSigner(cert, s) + require.NoError(t, err) + + return certSigner +}