diff --git a/command/ssh.go b/command/ssh.go index 5a3e0aa2ce91..986228d525ec 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -243,17 +243,27 @@ func (c *SSHCommand) Run(args []string) int { return 1 } - // Extract the username and IP. - username, hostname, ip, err := c.userHostAndIP(args[0]) + // Extract the hostname, username and port from the ssh command + hostname, username, port, err := c.parseSSHCommand(args) if err != nil { - c.UI.Error(fmt.Sprintf("Error parsing user and IP: %s", err)) + c.UI.Error(fmt.Sprintf("Error parsing the ssh command: %q", err)) return 1 } - // The rest of the args are ssh args - sshArgs := []string{} - if len(args) > 1 { - sshArgs = args[1:] + // Use the current user if no user was specified in the ssh command + if username == "" { + u, err := user.Current() + if err != nil { + c.UI.Error(fmt.Sprintf("Error getting the current user: %q", err)) + return 1 + } + username = u.Username + } + + ip, err := c.resolveHostname(hostname) + if err != nil { + c.UI.Error(fmt.Sprintf("Error resolving the ssh hostname: %q", err)) + return 1 } // Set the client in the command @@ -329,11 +339,11 @@ func (c *SSHCommand) Run(args []string) int { switch strings.ToLower(c.flagMode) { case ssh.KeyTypeCA: - return c.handleTypeCA(username, hostname, ip, sshArgs) + return c.handleTypeCA(username, ip, port, args) case ssh.KeyTypeOTP: - return c.handleTypeOTP(username, hostname, ip, sshArgs) + return c.handleTypeOTP(username, ip, port, args) case ssh.KeyTypeDynamic: - return c.handleTypeDynamic(username, ip, sshArgs) + return c.handleTypeDynamic(username, ip, port, args) default: c.UI.Error(fmt.Sprintf("Unknown SSH mode: %s", c.flagMode)) return 1 @@ -341,7 +351,7 @@ func (c *SSHCommand) Run(args []string) int { } // handleTypeCA is used to handle SSH logins using the "CA" key type. -func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []string) int { +func (c *SSHCommand) handleTypeCA(username, ip, port string, sshArgs []string) int { // Read the key from disk publicKey, err := ioutil.ReadFile(c.flagPublicKeyPath) if err != nil { @@ -460,10 +470,6 @@ func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []strin ) } - args = append(args, - username+"@"+hostname, - ) - // Add extra user defined ssh arguments args = append(args, sshArgs...) @@ -493,7 +499,7 @@ func (c *SSHCommand) handleTypeCA(username, hostname, ip string, sshArgs []strin } // handleTypeOTP is used to handle SSH logins using the "otp" key type. -func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs []string) int { +func (c *SSHCommand) handleTypeOTP(username, ip, port string, sshArgs []string) int { secret, cred, err := c.generateCredential(username, ip) if err != nil { c.UI.Error(fmt.Sprintf("failed to generate credential: %s", err)) @@ -543,10 +549,13 @@ func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs ) } + // If a port wasn't specified in the ssh arguments lets use the port we got back from vault + if port == "" { + args = append(args, "-p", cred.Port) + } + args = append(args, "-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking, - "-p", cred.Port, - username+"@"+hostname, ) // Add the rest of the ssh args appended by the user @@ -585,7 +594,7 @@ func (c *SSHCommand) handleTypeOTP(username, hostname string, ip string, sshArgs } // handleTypeDynamic is used to handle SSH logins using the "dyanmic" key type. -func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) int { +func (c *SSHCommand) handleTypeDynamic(username, ip, port string, sshArgs []string) int { // Generate the credential secret, cred, err := c.generateCredential(username, ip) if err != nil { @@ -610,13 +619,20 @@ func (c *SSHCommand) handleTypeDynamic(username, ip string, sshArgs []string) in return 1 } - args := append([]string{ + args := make([]string, 0) + // If a port wasn't specified in the ssh arguments lets use the port we got back from vault + if port == "" { + args = append(args, "-p", cred.Port) + } + + args = append(args, "-i", keyPath, - "-o UserKnownHostsFile=" + c.flagUserKnownHostsFile, - "-o StrictHostKeyChecking=" + c.flagStrictHostKeyChecking, - "-p", cred.Port, - username + "@" + ip, - }, sshArgs...) + "-o UserKnownHostsFile="+c.flagUserKnownHostsFile, + "-o StrictHostKeyChecking="+c.flagStrictHostKeyChecking, + ) + + // Add extra user defined ssh arguments + args = append(args, sshArgs...) cmd := exec.Command("ssh", args...) cmd.Stdin = os.Stdin @@ -745,37 +761,95 @@ func (c *SSHCommand) defaultRole(mountPoint, ip string) (string, error) { } } -// userAndIP takes an argument in the format foo@1.2.3.4 and separates the IP -// and user parts, returning any errors. -func (c *SSHCommand) userHostAndIP(s string) (string, string, string, error) { - // split the parameter username@ip - input := strings.Split(s, "@") - var username, address string - - // If only IP is mentioned and username is skipped, assume username to - // be the current username. Vault SSH role's default username could have - // been used, but in order to retain the consistency with SSH command, - // current username is employed. - switch len(input) { - case 1: - u, err := user.Current() - if err != nil { - return "", "", "", errors.Wrap(err, "failed to fetch current user") +// Finds the hostname, username (optional) and port (optional) from any valid ssh command +// Supports usrname@hostname but also specifying valid ssh flags like -o User=username, +// -o Port=2222 and -p 2222 anywhere in the command +func (c *SSHCommand) parseSSHCommand(args []string) (hostname string, username string, port string, err error) { + lastArg := "" + + for _, i := range args { + arg := lastArg + lastArg = "" + + // If -p has been specified then this is our ssh port + if arg == "-p" { + port = i + continue } - username, address = u.Username, input[0] - case 2: - username, address = input[0], input[1] - default: - return "", "", "", fmt.Errorf("invalid arguments: %q", s) + + // this is an ssh option, lets see if User or Port have been set and use it + if arg == "-o" { + split := strings.Split(i, "=") + key := split[0] + // Incase the value contains = signs we want to get all of them + value := strings.Join(split[1:], " ") + + if key == "User" { + // Don't overwrite the user if it is already set by username@hostname + // This matches the behaviour for how regular ssh reponds when both are specified + if username == "" { + username = value + } + } + + if key == "Port" { + // Don't overwrite the port if it is already set by -p + // This matches the behaviour for how regular ssh reponds when both are specified + if port == "" { + port = value + } + } + continue + } + + // This isn't an ssh argument that we care about. Lets keep on parsing the command + if arg != "" { + continue + } + + // If this is an ssh argument we want to look at the value + if strings.HasPrefix(i, "-") { + lastArg = i + continue + } + + // If we have gotten this far it means this is a bare argument + // The first bare argument is the hostname + // The second bare argument is the command to run on the remote host + + // If the hostname hasn't been set yet than it means we have found the first bare argument + if hostname == "" { + if strings.Contains(i, "@") { + split := strings.Split(i, "@") + username = split[0] + hostname = split[1] + } else { + hostname = i + } + continue + } else { + // The second bare argument is the command to run on the remote host. + // We need to break out and stop parsing arugments now + break + } + + } + if hostname == "" { + return "", "", "", errors.Wrap( + err, + fmt.Sprintf("failed to find a hostname in ssh command %q", strings.Join(args, " ")), + ) } + return hostname, username, port, nil +} +func (c *SSHCommand) resolveHostname(hostname string) (ip string, err error) { // Resolving domain names to IP address on the client side. // Vault only deals with IP addresses. - ipAddr, err := net.ResolveIPAddr("ip", address) + ipAddr, err := net.ResolveIPAddr("ip", hostname) if err != nil { - return "", "", "", errors.Wrap(err, "failed to resolve IP address") + return "", errors.Wrap(err, "failed to resolve IP address") } - ip := ipAddr.String() - - return username, address, ip, nil + ip = ipAddr.String() + return ip, nil } diff --git a/command/ssh_test.go b/command/ssh_test.go index 189ea2887f23..3ed4a9ea93b1 100644 --- a/command/ssh_test.go +++ b/command/ssh_test.go @@ -21,3 +21,136 @@ func TestSSHCommand_Run(t *testing.T) { t.Parallel() t.Skip("Need a way to setup target infrastructure") } + +func TestParseSSHCommand(t *testing.T) { + t.Parallel() + + _, cmd := testSSHCommand(t) + var tests = []struct { + name string + args []string + hostname string + username string + port string + err error + }{ + { + "Parse just a hostname", + []string{ + "hostname", + }, + "hostname", + "", + "", + nil, + }, + { + "Parse the standard username@hostname", + []string{ + "username@hostname", + }, + "hostname", + "username", + "", + nil, + }, + { + "Parse the username out of -o User=username", + []string{ + "-o", "User=username", + "hostname", + }, + "hostname", + "username", + "", + nil, + }, + { + "If the username is specified with -o User=username and realname@hostname prefer realname@", + []string{ + "-o", "User=username", + "realname@hostname", + }, + "hostname", + "realname", + "", + nil, + }, + { + "Parse the port out of -o Port=2222", + []string{ + "-o", "Port=2222", + "hostname", + }, + "hostname", + "", + "2222", + nil, + }, + { + "Parse the port out of -p 2222", + []string{ + "-p", "2222", + "hostname", + }, + "hostname", + "", + "2222", + nil, + }, + { + "If port is defined with -o Port=2222 and -p 2244 prefer -p", + []string{ + "-p", "2244", + "-o", "Port=2222", + "hostname", + }, + "hostname", + "", + "2244", + nil, + }, + { + "Ssh args with a command", + []string{ + "hostname", + "command", + }, + "hostname", + "", + "", + nil, + }, + { + "Flags after the ssh command are not pased because they are part of the command", + []string{ + "username@hostname", + "command", + "-p 22", + }, + "hostname", + "username", + "", + nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + hostname, username, port, err := cmd.parseSSHCommand(test.args) + if err != test.err { + t.Errorf("got error: %q want %q", err, test.err) + } + if hostname != test.hostname { + t.Errorf("got hostname: %q want %q", hostname, test.hostname) + } + if username != test.username { + t.Errorf("got username: %q want %q", username, test.username) + } + if port != test.port { + t.Errorf("got port: %q want %q", port, test.port) + } + }) + } +}