-
Notifications
You must be signed in to change notification settings - Fork 21
/
connection.go
132 lines (113 loc) · 3.52 KB
/
connection.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package ts3
import (
"errors"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
// Well-known TeamSpeak 3 ServerQuery ports. Connect falls back to these when
// the address it is given does not specify a port.
const (
	// DefaultPort is the default TeamSpeak 3 ServerQuery port for plain TCP.
	DefaultPort = 10011

	// DefaultSSHPort is the default TeamSpeak 3 ServerQuery SSH port.
	DefaultSSHPort = 10022
)
// legacyConnection is an insecure TCP connection.
//
// Traffic is sent unencrypted over a plain "tcp" dial. Read, Write, Close and
// the other net.Conn methods are promoted from the embedded Conn, which is
// only populated once Connect succeeds.
type legacyConnection struct {
	// Conn is the raw TCP connection set by Connect.
	net.Conn
}
// Connect dials addr over plain TCP, giving up once timeout has elapsed.
// When addr carries no port, DefaultPort is used. The resulting connection
// is stored in c.Conn.
func (c *legacyConnection) Connect(addr string, timeout time.Duration) error {
	target, err := verifyAddr(addr, DefaultPort)
	if err != nil {
		return err
	}

	conn, err := net.DialTimeout("tcp", target, timeout)
	c.Conn = conn
	if err != nil {
		return fmt.Errorf("legacy connection: dial: %w", err)
	}

	return nil
}
// sshConnection is an SSH connection with open SSH channel and attached shell.
//
// Connect sets Conn and channel; config must already be populated before
// Connect is called (presumably by a constructor outside this view — confirm).
// Read and Write go through channel rather than the raw TCP connection, and
// Close shuts down channel before the embedded Conn.
type sshConnection struct {
	// Conn is the raw TCP connection the SSH client connection is built on.
	net.Conn
	// config is handed to ssh.NewClientConn during Connect.
	config  *ssh.ClientConfig
	// channel is the "session" channel on which Connect requested a shell.
	channel ssh.Channel
}
// Connect connects to the address with the given timeout and opens a new SSH
// channel with attached shell.
//
// If addr has no port, DefaultSSHPort is used. timeout bounds the TCP dial and,
// when positive, also the SSH handshake and session setup that follow it, so a
// server that accepts TCP but never completes the handshake cannot block
// Connect indefinitely. If any step after the dial fails, the TCP connection
// is closed before the error is returned.
func (c *sshConnection) Connect(addr string, timeout time.Duration) (err error) {
	addr, err = verifyAddr(addr, DefaultSSHPort)
	if err != nil {
		return err
	}

	if c.Conn, err = net.DialTimeout("tcp", addr, timeout); err != nil {
		return fmt.Errorf("ssh connection: dial: %w", err)
	}

	// From here on a failure must not leak the TCP connection (and the SSH
	// client connection layered on top of it).
	defer func() {
		if err != nil {
			_ = c.Conn.Close() // best effort; the setup error is what gets reported
		}
	}()

	if timeout > 0 {
		// Bound handshake + session setup; cleared again once setup succeeds.
		if err = c.Conn.SetDeadline(time.Now().Add(timeout)); err != nil {
			return fmt.Errorf("ssh connection: set deadline: %w", err)
		}
	}

	clientConn, chans, reqs, err := ssh.NewClientConn(c.Conn, addr, c.config)
	if err != nil {
		return fmt.Errorf("ssh connection: ssh client conn: %w", err)
	}

	go ssh.DiscardRequests(reqs)

	// Reject all channel requests.
	go func(newChannels <-chan ssh.NewChannel) {
		for ch := range newChannels {
			ch.Reject(ssh.Prohibited, ssh.Prohibited.String()) //nolint: errcheck
		}
	}(chans)

	c.channel, reqs, err = clientConn.OpenChannel("session", nil)
	if err != nil {
		return fmt.Errorf("ssh connection: session: %w", err)
	}

	go ssh.DiscardRequests(reqs)

	ok, err := c.channel.SendRequest("shell", true, nil)
	if err != nil {
		return fmt.Errorf("ssh connection: shell: %w", err)
	}
	if !ok {
		return errors.New("ssh connection: could not open shell")
	}

	if timeout > 0 {
		// Setup finished in time; later reads and writes must not inherit the deadline.
		if err = c.Conn.SetDeadline(time.Time{}); err != nil {
			return fmt.Errorf("ssh connection: clear deadline: %w", err)
		}
	}

	return nil
}
// Read implements io.Reader by reading from the SSH session channel.
func (c *sshConnection) Read(p []byte) (n int, err error) {
	// The channel's error is passed through untouched: io.Reader requires a bare
	// io.EOF at end of stream (https://pkg.go.dev/io#Reader), so it must not be wrapped.
	n, err = c.channel.Read(p) //nolint: wrapcheck
	return n, err
}
// Write implements io.Writer by writing p to the SSH session channel.
func (c *sshConnection) Write(p []byte) (n int, err error) {
	n, err = c.channel.Write(p) //nolint: wrapcheck
	return n, err
}
// Close implements io.Closer.
//
// It closes the SSH session channel first and then the underlying TCP
// connection, returning the first error worth reporting. Errors that only
// signal the peer or connection is already gone are dropped:
//   - io.EOF from the channel,
//   - net.ErrClosed from the connection,
//   - "connection reset by peer" from either.
//
// Fields that were never set are skipped instead of causing a nil-interface
// panic. This covers Close before Connect, or Close after a Connect that failed
// before the session channel was opened.
func (c *sshConnection) Close() error {
	var err error

	if c.channel != nil {
		if cerr := c.channel.Close(); cerr != nil &&
			!errors.Is(cerr, io.EOF) &&
			!strings.HasSuffix(cerr.Error(), "connection reset by peer") {
			err = cerr
		}
	}

	if c.Conn != nil {
		// Only report the connection's error if the channel close was clean.
		if cerr := c.Conn.Close(); cerr != nil &&
			err == nil &&
			!errors.Is(cerr, net.ErrClosed) &&
			!strings.HasSuffix(cerr.Error(), "connection reset by peer") {
			err = cerr
		}
	}

	return err
}
// verifyAddr checks that addr is either "host:port" or a bare host and returns
// it in canonical "host:port" form.
//
// If addr has no port, or an empty one (e.g. "host:"), defaultPort is used.
// A literal IPv6 must be enclosed in square brackets, e.g. "[::1]" or
// "[::1]:10011"; an unbracketed "::1" is rejected. Malformed input such as
// "[::1]x" or "host]" is reported as an error rather than mangled.
func verifyAddr(addr string, defaultPort int) (string, error) {
	host, port, err := net.SplitHostPort(addr)
	if err == nil {
		if port == "" {
			// An empty port is treated the same as a missing one.
			port = strconv.Itoa(defaultPort)
		}
		return net.JoinHostPort(host, port), nil
	}

	var addrErr *net.AddrError
	if !errors.As(err, &addrErr) || addrErr.Err != "missing port in address" {
		return "", fmt.Errorf("verify address: %w", err)
	}

	// No port given. Strip exactly one pair of enclosing brackets so that
	// JoinHostPort does not double-bracket an IPv6 literal like "[::1]".
	host = addr
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	withPort := net.JoinHostPort(host, strconv.Itoa(defaultPort))

	// SplitHostPort also reports "missing port" for malformed input such as
	// "[::1]x", so re-check the result instead of returning a mangled address.
	if _, _, vErr := net.SplitHostPort(withPort); vErr != nil {
		return "", fmt.Errorf("verify address: %w", err)
	}

	return withPort, nil
}