forked from mk-fg/fgtk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssh-reverse-mux-server
executable file
·244 lines (201 loc) · 9.05 KB
/
ssh-reverse-mux-server
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
#!/usr/bin/env python3
import os, sys, logging, contextlib, asyncio, socket, signal
import math, hashlib, secrets, struct, shelve, base64
class SSHMuxConfig:
    'Built-in defaults, used by main() as fallback values for command-line options.'

    # Symmetric secret for the --auth-secret option
    auth_secret: str = 'set this via cli option!'

    # UDP negotiation defaults: --mux-port, --attempts and --timeout (seconds)
    mux_port: int = 8739
    mux_attempts: int = 4
    mux_timeout: float = 5.0

    # Ports sent to clients: --ssh-port and --tunnel-port-range ("from:to", inclusive)
    ssh_port: int = 22
    tunnel_port_range: str = '22000:22100'

    # Default path for the --ident-db option
    ident_db_path: str = 'ssh-reverse-mux-ident.db'
class LogMessage:
    'Log message with deferred str.format() - rendered only when converted to str.'

    def __init__(self, fmt, a, k):
        self.fmt = fmt
        self.a = a
        self.k = k

    def __str__(self):
        # Template is returned verbatim when there is nothing to substitute
        if not (self.a or self.k):
            return self.fmt
        return self.fmt.format(*self.a, **self.k)
class LogStyleAdapter(logging.LoggerAdapter):
    'LoggerAdapter which wraps each message into LogMessage for str.format()-style templates.'

    def __init__(self, logger, extra=None):
        if not extra:
            extra = {}
        super().__init__(logger, extra)

    def log(self, level, msg, *args, **kws):
        if not self.isEnabledFor(level):
            return
        # Only exc_info is forwarded to the wrapped logger as a keyword,
        # everything else in kws goes to LogMessage as format() keywords
        log_kws = {}
        if 'exc_info' in kws:
            log_kws['exc_info'] = kws.pop('exc_info')
        msg, kws = self.process(msg, kws)
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    'Return LogStyleAdapter wrapping the stdlib logger with the specified name.'
    return LogStyleAdapter(logging.getLogger(name))
def retries_within_timeout(
        tries, timeout,
        backoff_func=lambda e, n: ((e**n - 1) / e), slack=1e-2):
    '''Return list of delays to make exactly n tries within timeout, with backoff_func.

    Bisects the first argument of backoff_func(e, n) over the (0, timeout) interval,
    until the sum of delays for n in range(tries) is within "slack" of timeout.

    Raises ValueError if tries/timeout are not positive, or if no value in that
    interval gives such a sum (e.g. too few tries for the timeout with default
    backoff_func) - the search would otherwise never terminate.'''
    if tries < 1 or timeout <= 0:
        raise ValueError(
            f'Both tries and timeout must be positive: tries={tries!r} timeout={timeout!r}')
    a, b = 0, timeout
    while True:
        m = (a + b) / 2
        # Midpoint equal to one of the bounds means that float precision
        # is exhausted, and same value would be re-evaluated forever
        if m == a or m == b:
            raise ValueError(
                f'Unable to fit {tries!r} tries into {timeout!r}s timeout'
                ' with specified backoff function, try increasing either value')
        delays = [backoff_func(m, n) for n in range(tries)]
        error = sum(delays) - timeout
        if abs(error) < slack:
            return delays
        if error > 0:
            b = m
        else:
            a = m
def to_bytes(s):
    'Return bytes values as-is, and str() representation of anything else encoded to bytes.'
    if isinstance(s, bytes):
        return s
    return str(s).encode()
class MuxServerProtocol(asyncio.DatagramProtocol):
    '''Datagram protocol which queues all received packets for processing elsewhere.

    Every received datagram is put into the "requests" queue as a (data, addr) tuple,
    and a single None is put there after connection is lost.'''

    transport = None

    def __init__(self, loop):
        # "loop" is accepted for interface compatibility but not passed on:
        # asyncio.Queue(loop=...) is deprecated since Python 3.8 and raises TypeError since 3.10
        self.requests = asyncio.Queue()
        self.log = get_logger('mux-server.udp')

    def connection_made(self, transport):
        self.log.debug('Connection made')
        self.transport = transport

    def datagram_received(self, data, addr):
        self.log.debug('Received {:,d}B from {!r}', len(data), addr)
        self.requests.put_nowait((data, addr))

    def error_received(self, err):
        # Transport reports send/receive errors via this callback, these are only logged
        self.log.debug('Network error: {}', err)

    def connection_lost(self, err):
        self.log.debug('Connection lost: {}', err)
        # None marks the end of the requests stream for the queue consumer
        self.requests.put_nowait(None)
class AuthError(Exception):
    'Raised for requests with invalid structure or mismatching MAC.'
def parse_req(secret, req):
    '''Parse and authenticate raw request packet, returning ident string or None.

    Packet layout: 1-byte ident length, ident itself, 16-byte salt, then MAC -
    blake2b digest of ident, keyed with secret and using that salt.

    Returns urlsafe-base64 encoding of ident (str) on success.
    Malformed or unauthenticated packets are logged via module-level
    "log" (set up in main) and None is returned for them.'''
    try:
        # Zero-length datagrams are valid in UDP, and have to be rejected here
        # explicitly, as indexing them below would raise uncaught IndexError
        if not req:
            raise AuthError('Invalid structure')
        ident_len = req[0]
        ident_end = ident_len + 1
        salt_end = ident_end + 16
        ident, salt, mac = req[1:ident_end], req[ident_end:salt_end], req[salt_end:]
        # Slices of a truncated packet are silently shorter/empty, hence these checks
        if len(ident) != ident_len or not salt or not mac:
            raise AuthError('Invalid structure')
        mac_chk = hashlib.blake2b(ident, key=to_bytes(secret), salt=salt).digest()
        # Constant-time comparison for the MAC value
        if not secrets.compare_digest(mac, mac_chk):
            raise AuthError('MAC mismatch')
    except AuthError as err:
        log.debug('Failed to parse/auth request value: {}', err)
        return
    return base64.urlsafe_b64encode(ident).decode()
def build_res(secret, ident, tun_port, ssh_port):
    '''Build authenticated response packet with ssh and tunnel ports.

    Packet layout: 1-byte payload length, payload (ssh_port and tun_port
    as big-endian unsigned 16-bit values), 16-byte random salt, then MAC -
    blake2b digest of ident + payload, keyed with secret and using that salt.'''
    payload = struct.pack('>HH', ssh_port, tun_port)
    salt = os.urandom(16)
    signed = to_bytes(ident) + payload
    hasher = hashlib.blake2b(key=to_bytes(secret), salt=salt)
    hasher.update(signed)
    header = struct.pack('>B', len(payload))
    return b''.join([header, payload, salt, hasher.digest()])
def ident_repr(ident):
    '''Return printable representation of urlsafe-base64 ident value.

    Result is "[str] <repr>" with decoded text if ident bytes can be
    decoded as text, and "[b64] <repr>" with the encoded value otherwise.'''
    try:
        text = base64.urlsafe_b64decode(ident).decode()
    except UnicodeDecodeError:
        return '[b64] {!r}'.format(ident)
    return '[str] {!r}'.format(text)
async def mux_send(loop, transport, response, addr, delays):
    '''Send response to addr via transport once per value in delays, sleeping for that delay after each send.

    "loop" is accepted for interface compatibility but is not used:
    asyncio.sleep(loop=...) is deprecated since Python 3.8 and raises TypeError since 3.10.'''
    for delay in delays:
        transport.sendto(response, addr)
        await asyncio.sleep(delay)
async def mux_listen(
        loop, secret, ident_db, sock_af, sock_p,
        host, port, tun_port_a, tun_port_b, ssh_port, delays):
    '''Listen for mux requests on UDP host:port and respond with allocated ports.

    For each request authenticated by parse_req, tunnel port for its ident is looked up
    in ident_db or allocated from tun_port_a-tun_port_b range (inclusive) and stored there,
    then response is sent via mux_send task, repeated according to "delays".
    Requests for ident which still has response being sent are ignored.

    Runs until cancelled or until the endpoint connection is lost,
    waiting for pending responses and closing the transport on exit.'''

    def tun_ports_iter_func():
        # Lazily yields ports from the range which were not stored in ident_db at the time of first use
        tun_ports_used = set(ident_db.values())
        for tun_port_free in range(tun_port_a, tun_port_b + 1):
            if tun_port_free not in tun_ports_used:
                yield tun_port_free

    tun_ports_iter = tun_ports_iter_func()
    responses = dict()
    transport, proto = await loop.create_datagram_endpoint(
        lambda: MuxServerProtocol(loop),
        local_addr=(host, port), family=sock_af, proto=sock_p)
    try:
        while True:
            request = await proto.requests.get()
            # Protocol puts None into the queue when connection is lost - stop instead
            # of failing with TypeError on trying to unpack it
            if request is None:
                log.debug('Mux endpoint connection lost, stopping listener')
                break
            req, addr = request
            ident = parse_req(secret, req)
            if not ident:
                continue

            if ident in responses:
                if not responses[ident].done():
                    continue
                # Collects result of the finished task, re-raising its exception, if any
                await responses[ident]

            tun_port = ident_db.get(ident)
            if not tun_port or not tun_port_a <= tun_port <= tun_port_b:
                try:
                    tun_port = next(tun_ports_iter)
                except StopIteration:
                    log.debug(
                        'No more ports to allocate for'
                        ' ident: {} (addr={})', ident_repr(ident), addr)
                    continue
                ident_db[ident] = tun_port
                ident_db.sync()

            response = build_res(secret, ident, tun_port, ssh_port)
            log.debug(
                'Allocated [tun={}, ssh={}] for ident: {} (addr={})',
                tun_port, ssh_port, ident_repr(ident), addr)
            responses[ident] = loop.create_task(
                mux_send(loop, transport, response, addr, delays))
    finally:
        # Transport is closed even if waiting for one of the responses fails
        try:
            for response in responses.values():
                await response
        finally:
            transport.close()
def sockopt_resolve(prefix, v):
    '''Return name of the socket module constant with specified prefix and value v.

    Prefix is matched in upper-case and stripped from the returned name.
    Value v itself is returned if there is no such constant.'''
    tag = prefix.upper()
    matches = (
        name[len(tag):] for name in dir(socket)
        if name.startswith(tag) and getattr(socket, name) == v)
    return next(matches, v)
def _build_arg_parser(conf):
    'Return command-line parser for the server, with option defaults taken from conf.'
    import argparse
    parser = argparse.ArgumentParser(
        description='Multiplexer for "ssh -R" connections,'
            ' directing each one to unique port(s) according to provided ident-string.')
    parser.add_argument('bind', nargs='?', default='::',
        help='Host or address (to be resolved via gai) to listen on.'
            ' Default is to use "::" wildcard IPv4/IPv6 binding.')
    parser.add_argument('-s', '--auth-secret',
        default=conf.auth_secret, metavar='string',
        help='Any string to use as symmetric secret'
            ' to authenticate both sides on --mux-port (default: %(default)s).'
            ' Must be same for both ssh-reverse-mux-client and server scripts talking to each other.')
    parser.add_argument('-i', '--ident-db',
        default=conf.ident_db_path, metavar='path',
        help='Path to db to store all the seen clients to, for persistent port allocation.'
            ' Default: %(default)s')
    parser.add_argument('-l', '--ident-list',
        action='store_true', help='List stored ident-port mappings and exit.')
    parser.add_argument('-m', '--mux-port',
        default=conf.mux_port, type=int, metavar='port',
        help='Local UDP port to listen on for muxer requests from clients (default: %(default)s).'
            ' Can also be specified in the "bind" argument, which overrides this option.')
    parser.add_argument('-p', '--ssh-port',
        type=int, metavar='port', default=conf.ssh_port,
        help='Local sshd port to send to clients (default: %(default)s).')
    parser.add_argument('-r', '--tunnel-port-range',
        metavar='port_from:port_to', default=conf.tunnel_port_range,
        help='Range in which to allocate'
            ' reverse-tunnel listening ports, inclusive. Default: %(default)s')
    parser.add_argument('-n', '--attempts',
        type=int, metavar='n', default=conf.mux_attempts,
        help='Number of UDP packets to send from'
            ' --mux-port in response to clients (to offset packet loss). Default: %(default)s')
    parser.add_argument('-t', '--timeout',
        type=float, metavar='seconds', default=conf.mux_timeout,
        help='Negotiation response timeout on --mux-port, in seconds. Default: %(default)ss')
    parser.add_argument('-d', '--debug', action='store_true', help='Verbose operation mode.')
    return parser


def _parse_tunnel_port_range(spec):
    '''Parse "port_from:port_to" string into (port_from, port_to) tuple of ints.
    Raises ValueError for malformed spec, ports outside of 1-65535 or reversed range.'''
    tun_port_a, tun_port_b = map(int, spec.split(':', 1))
    for p in tun_port_a, tun_port_b:
        # 65535 is the highest valid port number, so it has to be accepted here
        if not 0 < p <= 65535:
            raise ValueError(f'Out of range: {p!r}')
    if tun_port_a > tun_port_b:
        raise ValueError(f'Range start is greater than its end: {tun_port_a} > {tun_port_b}')
    return tun_port_a, tun_port_b


def main(args=None, conf=None):
    '''Script entry point: parse command-line options and run the mux server.

    args: list of command-line arguments, sys.argv[1:] is used if None.
    conf: object with option defaults, SSHMuxConfig instance is used if not specified.

    Sets module-level "log" logger as a side-effect, which other functions here use.
    Returns after listing stored idents (with --ident-list), or when
    the listener task finishes or gets cancelled by SIGINT/SIGTERM.
    Invalid options are reported via parser.error(), which exits the script.'''
    if not conf:
        conf = SSHMuxConfig()
    parser = _build_arg_parser(conf)
    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    global log
    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)
    log = get_logger('mux-server.main')

    # Shelf is used as a context manager to have it closed on every exit path,
    # including early return for --ident-list and SystemExit from parser.error()
    with shelve.open(opts.ident_db, 'c') as ident_db:
        if opts.ident_list:
            for ident, tun_port in ident_db.items():
                print('port {} :: {}'.format(tun_port, ident_repr(ident)))
            return

        # NOTE(review): --mux-port help mentions port in the "bind" argument,
        # but bind value is passed to getaddrinfo as a host here, without any parsing
        host, port_mux = opts.bind, opts.mux_port
        try:
            addrinfo = socket.getaddrinfo(
                host, str(port_mux), type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP)
            if not addrinfo:
                raise socket.gaierror('No addrinfo for host: {}'.format(host))
        except (socket.gaierror, socket.error) as err:
            parser.error(
                'Failed to resolve socket parameters (address, family)'
                ' via getaddrinfo: {!r} - [{}] {}'.format(
                    (host, port_mux), err.__class__.__name__, err))
        sock_af, sock_t, sock_p, _, sock_addr = addrinfo[0]
        log.debug(
            'Resolved mux host:port {!r}:{!r} to endpoint: {} (family: {}, type: {}, proto: {})',
            host, port_mux, sock_addr,
            *(sockopt_resolve(pre, n) for pre, n in [
                ('af_', sock_af), ('sock_', sock_t), ('ipproto_', sock_p)]))
        host, port_mux = sock_addr[:2]

        try:
            tun_port_a, tun_port_b = _parse_tunnel_port_range(opts.tunnel_port_range)
        except ValueError as err:
            parser.error(f'Failed to parse tunnel port range: [{err.__class__.__name__}] {err}')
        log.debug(
            'Parsed tunnel port range: {}-{} ({} port(s))',
            tun_port_a, tun_port_b, tun_port_b - tun_port_a + 1)

        # Bad --attempts/--timeout values are reported as usage errors, not tracebacks
        try:
            retry_delays = retries_within_timeout(opts.attempts, opts.timeout)
        except ValueError as err:
            parser.error(f'Invalid --attempts/--timeout values: {err}')

        # New loop is created explicitly, as asyncio.get_event_loop()
        # is deprecated for use when there is no running event loop
        with contextlib.closing(asyncio.new_event_loop()) as loop:
            muxer = loop.create_task(mux_listen(
                loop, opts.auth_secret, ident_db, sock_af, sock_p,
                host, port_mux, tun_port_a, tun_port_b, opts.ssh_port, retry_delays))
            for sig in 'INT', 'TERM':
                loop.add_signal_handler(getattr(signal, f'SIG{sig}'), muxer.cancel)
            try:
                return loop.run_until_complete(muxer)
            except asyncio.CancelledError:
                return
# Run main() when started as a script, using its return value as an exit status
if __name__ == '__main__':
    sys.exit(main())