forked from moul/sshportal
-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
158 lines (139 loc) · 3.95 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package main
import (
"fmt"
"log"
"math"
"net"
"os"
"time"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/gliderlabs/ssh"
"github.com/urfave/cli"
gossh "golang.org/x/crypto/ssh"
"moul.io/sshportal/pkg/bastion"
)
// serverConfig carries the runtime settings of the SSH bastion server, as
// collected from the command-line flags by parseServerConfig.
type serverConfig struct {
	// aesKey comes from the "aes-key" flag. parseServerConfig only accepts a
	// length of 0, 16, 24 or 32 bytes. It is forwarded to the bastion auth
	// handlers and to bastion.PrivateKeyFromDB.
	aesKey string
	// dbDriver selects the gorm dialector; dbConnect recognizes "sqlite3",
	// "postgres" and "mysql".
	dbDriver string
	// dbURL is the connection string handed to the selected dialector.
	dbURL string
	// logsLocation is forwarded to the bastion auth handlers.
	// NOTE(review): presumably the directory session logs are written to — confirm in pkg/bastion.
	logsLocation string
	// bindAddr is the TCP address the SSH server listens on.
	bindAddr string
	// debug mirrors the "debug" flag.
	debug bool
	// demo mirrors the "demo" flag and is forwarded to the bastion auth handlers.
	demo bool
	// idleTimeout is applied to the SSH server only when it is non-zero.
	idleTimeout time.Duration
	// aclCheckCmd is forwarded to the bastion auth handlers.
	aclCheckCmd string
}
// parseServerConfig reads the server-related flags from c and returns them as
// a serverConfig.
//
// The only validation performed is on the "aes-key" flag: its length must be
// 0 (no key), 16, 24 or 32 bytes, otherwise a nil config and an error are
// returned.
func parseServerConfig(c *cli.Context) (*serverConfig, error) {
	cfg := new(serverConfig)
	cfg.aesKey = c.String("aes-key")
	cfg.dbDriver = c.String("db-driver")
	cfg.dbURL = c.String("db-conn")
	cfg.bindAddr = c.String("bind-address")
	cfg.debug = c.Bool("debug")
	cfg.demo = c.Bool("demo")
	cfg.logsLocation = c.String("logs-location")
	cfg.idleTimeout = c.Duration("idle-timeout")
	cfg.aclCheckCmd = c.String("acl-check-cmd")

	// An empty key is allowed; any other length must be one of the three
	// accepted sizes.
	if n := len(cfg.aesKey); n != 0 && n != 16 && n != 24 && n != 32 {
		return nil, fmt.Errorf("invalid aes key size, should be 16 or 24, 32")
	}
	return cfg, nil
}
// ensureLogDirectory makes sure that location exists and is a directory.
//
// When nothing exists at location, the directory (and any missing parents) is
// created with permission bits 0750. When location exists but is not a
// directory, an error naming the offending path is returned. Errors coming
// from os.Stat or os.MkdirAll are returned unwrapped, so callers that inspect
// them with os.IsPermission and friends keep working.
func ensureLogDirectory(location string) error {
	info, err := os.Stat(location)
	switch {
	case err == nil:
		if !info.IsDir() {
			return fmt.Errorf("log location %q exists but is not a directory", location)
		}
		return nil
	case os.IsNotExist(err):
		// os.MkdirAll only needs the permission bits; it always creates directories.
		return os.MkdirAll(location, 0750)
	default:
		return err
	}
}
// dbConnect opens a gorm database using the dialector selected by c.dbDriver
// ("sqlite3", "postgres" or "mysql") and the connection string c.dbURL.
// config is passed through to gorm.Open unchanged.
//
// An unrecognized driver name yields an error instead of calling a nil
// dialector constructor, which would panic.
func dbConnect(c *serverConfig, config gorm.Option) (*gorm.DB, error) {
	var dbOpen func(string) gorm.Dialector
	switch c.dbDriver {
	case "sqlite3":
		dbOpen = sqlite.Open
	case "postgres":
		dbOpen = postgres.Open
	case "mysql":
		dbOpen = mysql.Open
	default:
		return nil, fmt.Errorf("unsupported database driver %q", c.dbDriver)
	}
	return gorm.Open(dbOpen(c.dbURL), config)
}
// server connects to the database described by c, runs bastion.DBInit on it,
// opens a TCP listener on c.bindAddr and serves SSH connections on it. It
// returns when setup fails or when srv.Serve returns.
//
// The result is named so that the deferred database close can report its own
// error when nothing else went wrong.
func server(c *serverConfig) (err error) {
	// configure db logging
	db, err := dbConnect(c, &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		// db must not be used when the connection failed.
		return fmt.Errorf("connecting to database: %w", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		return fmt.Errorf("retrieving database handle: %w", err)
	}
	defer func() {
		// Keep the first error; the Close error only surfaces when everything
		// else succeeded.
		if closeErr := sqlDB.Close(); err == nil {
			err = closeErr
		}
	}()
	if err = sqlDB.Ping(); err != nil {
		return err
	}
	if err = bastion.DBInit(db); err != nil {
		return err
	}

	// create TCP listening socket
	ln, err := net.Listen("tcp", c.bindAddr)
	if err != nil {
		return err
	}

	// configure server
	srv := &ssh.Server{
		Addr:    c.bindAddr,
		Handler: func(s ssh.Session) { bastion.ShellHandler(s, GitTag, GitSha, GitTag) }, // ssh.Server.Handler is the handler for the DefaultSessionHandler
		Version: fmt.Sprintf("sshportal-%s", GitTag),
		ChannelHandlers: map[string]ssh.ChannelHandler{
			"default": bastion.ChannelHandler,
		},
	}

	// configure channel handler: sessions and direct-tcpip are dispatched to
	// the gliderlabs handlers, every other channel type is rejected.
	bastion.DefaultChannelHandler = func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
		switch newChan.ChannelType() {
		case "session":
			go ssh.DefaultSessionHandler(srv, conn, newChan, ctx)
		case "direct-tcpip":
			go ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
		default:
			if err := newChan.Reject(gossh.UnknownChannelType, "unsupported channel type"); err != nil {
				log.Printf("failed to reject chan: %v", err)
			}
		}
	}

	if c.idleTimeout != 0 {
		srv.IdleTimeout = c.idleTimeout
		// gliderlabs/ssh requires MaxTimeout to be non-zero if we want to use IdleTimeout.
		// So, set it to the max value, because we don't want a max timeout.
		srv.MaxTimeout = math.MaxInt64
	}

	for _, opt := range []ssh.Option{
		// custom PublicKeyAuth handler
		ssh.PublicKeyAuth(bastion.PublicKeyAuthHandler(db, c.logsLocation, c.aclCheckCmd, c.aesKey, c.dbDriver, c.dbURL, c.bindAddr, c.demo)),
		ssh.PasswordAuth(bastion.PasswordAuthHandler(db, c.logsLocation, c.aclCheckCmd, c.aesKey, c.dbDriver, c.dbURL, c.bindAddr, c.demo)),
		// retrieve sshportal SSH private key from database
		bastion.PrivateKeyFromDB(db, c.aesKey),
	} {
		if err := srv.SetOption(opt); err != nil {
			// Serve was never reached, so release the listening socket here;
			// the option error is the one worth reporting.
			_ = ln.Close()
			return err
		}
	}

	// NOTE(review): "idle-timout" is a typo for "idle-timeout"; the string is
	// kept unchanged in case log consumers match on it.
	log.Printf("info: SSH Server accepting connections on %s, idle-timout=%v", c.bindAddr, c.idleTimeout)
	return srv.Serve(ln)
}