forked from hashicorp/dawdle
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdawdle.go
452 lines (385 loc) · 10.4 KB
/
dawdle.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
// Package dawdle provides a simple proxy for testing network
// connections, offering various facilities to introduce unfavorable
// network conditions.
//
// As the package is designed for use in testing, a large amount of
// functionality is exported. It's recommended that you use the
// highest level of composition that makes sense for you; if you need
// access to the lower-level parts of the package, they are available
// to you.
package dawdle
import (
"errors"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
)
// defaultBufferSize is the default size, in bytes, of the read and
// write copy buffers for proxied connections (32 KiB).
const defaultBufferSize = 32 * 1024
// ErrNewProxy wraps err to signal that creating the proxy failed.
// The supplied error stays reachable through unwrapping.
func ErrNewProxy(err error) error {
	const format = "error creating proxy: %w"
	return fmt.Errorf(format, err)
}
// ErrProxyListener wraps err to signal that the proxy listener could
// not be started. The supplied error stays reachable through
// unwrapping.
func ErrProxyListener(err error) error {
	wrapped := fmt.Errorf("error starting listener: %w", err)
	return wrapped
}
// ErrProxyRun wraps err to signal a failure while the proxy was
// running. The supplied error stays reachable through unwrapping.
func ErrProxyRun(err error) error {
	const format = "error running proxy: %w"
	return fmt.Errorf(format, err)
}
// ErrProxyHandleRemoteConnect wraps err to signal that a handler
// could not connect to the remote. The supplied error stays reachable
// through unwrapping.
func ErrProxyHandleRemoteConnect(err error) error {
	wrapped := fmt.Errorf("error connecting to remote: %w", err)
	return wrapped
}
// ErrProxyHandleStream wraps err to signal a failure in the proxied
// network stream. The supplied error stays reachable through
// unwrapping.
func ErrProxyHandleStream(err error) error {
	const format = "error in network stream: %w"
	return fmt.Errorf(format, err)
}
// ErrProxyHandleCloseListener wraps err to signal that closing the
// listener failed. The supplied error stays reachable through
// unwrapping.
func ErrProxyHandleCloseListener(err error) error {
	wrapped := fmt.Errorf("error closing listener: %w", err)
	return wrapped
}
// ErrProxyClose flattens errs into a single error for a general
// close failure. The message is every individual message, in order,
// each followed by a newline. The individual errors are not wrapped,
// so they cannot be recovered from the result by unwrapping.
func ErrProxyClose(errs []error) error {
	var msg strings.Builder
	for i := range errs {
		msg.WriteString(errs[i].Error() + "\n")
	}
	return errors.New(msg.String())
}
// ErrProxyCloseListener wraps err to signal that closing the
// listener failed. The supplied error stays reachable through
// unwrapping.
func ErrProxyCloseListener(err error) error {
	const format = "error closing listener: %w"
	return fmt.Errorf(format, err)
}
// ErrProxyCloseConnections reports failures closing connections. The
// message is a header line followed by each individual message on its
// own tab-indented line. Individual errors are not wrapped.
func ErrProxyCloseConnections(errs []error) error {
	lines := make([]string, 0, len(errs))
	for _, e := range errs {
		lines = append(lines, e.Error())
	}
	joined := strings.Join(lines, "\n\t")
	return fmt.Errorf("error closing connections:\n\t%s", joined)
}
// ProxyOption is a functional option that adjusts a proxy while
// NewProxy is constructing it. A non-nil error returned by an option
// makes NewProxy fail.
type ProxyOption func(p *proxy) error
// WithRbufSize returns an option that sets the read buffer size, in
// bytes, for proxied connections. Any size below 1 selects the
// default (32k). The returned option never fails.
func WithRbufSize(size int) func(p *proxy) error {
	effective := size
	if effective < 1 {
		effective = defaultBufferSize
	}
	return func(p *proxy) error {
		p.rbufSize = effective
		return nil
	}
}
// WithWbufSize returns an option that sets the write buffer size, in
// bytes, for proxied connections. Any size below 1 selects the
// default (32k). The returned option never fails.
func WithWbufSize(size int) func(p *proxy) error {
	effective := size
	if effective < 1 {
		effective = defaultBufferSize
	}
	return func(p *proxy) error {
		p.wbufSize = effective
		return nil
	}
}
// WithLogger returns an option that installs logger as the
// destination for deep errors and other debugging output. The
// returned option never fails.
func WithLogger(logger *log.Logger) func(p *proxy) error {
	apply := func(p *proxy) error {
		p.logger = logger
		return nil
	}
	return apply
}
// WithListener returns an option that uses an existing listener as
// the local side of the proxy.
//
// Note that if this is passed in, localAddr in NewProxy is ignored,
// and the server is started immediately.
//
// The protocol of the listener needs to match the protocol passed
// into NewProxy. Only TCP listeners are allowed. The option returns an
// error when ln is nil, when ln is not a *net.TCPListener, or when the
// proxy protocol is not "tcp".
func WithListener(ln net.Listener) func(p *proxy) error {
	return func(p *proxy) error {
		// A nil interface would otherwise reach ln.Addr() below and
		// panic instead of reporting a usable error.
		if ln == nil {
			return errors.New("listener must not be nil")
		}
		if _, isTCP := ln.(*net.TCPListener); !isTCP {
			return fmt.Errorf("unsupported listener protocol %s", ln.Addr().Network())
		}
		if p.proto != "tcp" {
			return fmt.Errorf("listener type mismatch: TCP listener for %s proto", p.proto)
		}
		p.ln = ln
		return nil
	}
}
// proxy represents a proxy server: it accepts connections on a local
// listener and relays each one to a remote address.
type proxy struct {
// proto is the network protocol; NewProxy only accepts "tcp".
proto string
// localAddr is the address startListener listens on.
localAddr string
// remoteAddr is the address dialed for every handled connection.
remoteAddr string
// ln is the listener; it is nil until the listener has been started
// or supplied through WithListener.
ln net.Listener
// conns tracks accepted connections so they can be closed together.
conns *connMap
// rbufSize and wbufSize are the byte counts copied per iteration in
// the local-to-remote and remote-to-local directions respectively.
rbufSize, wbufSize int
// pauseCh gates the copy loops: handlers receive from it before every
// copy, so a closed channel lets traffic flow and an open one blocks.
pauseCh chan struct{}
// pauseChMutex guards reading and replacing pauseCh.
pauseChMutex *sync.RWMutex
// logger receives log output; when nil nothing is logged.
logger *log.Logger
}
// connMap is a lock-guarded collection of connections keyed by the
// string form of each connection's remote address.
type connMap struct {
// The embedded RWMutex is taken by every method before touching m.
*sync.RWMutex
// m maps remote address strings to their connections.
m map[string]net.Conn
}
// newConnMap returns an empty connMap with its mutex and backing map
// allocated and ready for use.
func newConnMap() *connMap {
	cm := &connMap{m: map[string]net.Conn{}}
	cm.RWMutex = &sync.RWMutex{}
	return cm
}
// Store records c under the string form of its remote address,
// replacing any connection already stored under that address. The
// write lock is held for the duration.
func (m *connMap) Store(c net.Conn) {
	m.Lock()
	defer m.Unlock()

	key := c.RemoteAddr().String()
	m.m[key] = c
}
// Delete removes the entry stored under the string form of c's
// remote address, if any. The connection itself is not closed. The
// write lock is held for the duration.
func (m *connMap) Delete(c net.Conn) {
	m.Lock()
	defer m.Unlock()

	key := c.RemoteAddr().String()
	delete(m.m, key)
}
// CloseAll closes and deletes all existing connections in the
// connMap. A single lock is held for the duration.
//
// Connections are deleted regardless of whether or not there is an
// error. It returns the errors reported by the individual Close
// calls, or nil when every connection closed cleanly.
func (m *connMap) CloseAll() []error {
	m.Lock()
	defer m.Unlock()

	var errs []error
	for key, c := range m.m {
		if err := c.Close(); err != nil {
			errs = append(errs, err)
		}
		// Remove by the key the entry was stored under rather than
		// asking the already-closed connection for its address again,
		// so the entry is dropped even if that address is no longer
		// the same. Deleting the current key while ranging is allowed.
		delete(m.m, key)
	}
	return errs
}
// NewProxy builds a proxy that relays connections made to localAddr
// on to remoteAddr, then applies opts in order.
//
// Currently the only protocol supported is "tcp"; remoteAddr is
// resolved up front so a bad address is reported here rather than on
// first use. Any validation or option failure is returned wrapped by
// ErrNewProxy. The proxy starts out un-paused. If an option supplied a
// listener, the accept loop is started in a background goroutine
// before returning, and the error that ends it is logged.
func NewProxy(proto, localAddr, remoteAddr string, opts ...ProxyOption) (*proxy, error) {
	if proto != "tcp" {
		return nil, ErrNewProxy(fmt.Errorf("unsupported protocol %s", proto))
	}
	if _, err := net.ResolveTCPAddr(proto, remoteAddr); err != nil {
		return nil, ErrNewProxy(err)
	}

	// A closed pause channel never blocks receivers, i.e. "running".
	running := make(chan struct{})
	close(running)

	p := &proxy{
		proto:        proto,
		localAddr:    localAddr,
		remoteAddr:   remoteAddr,
		conns:        newConnMap(),
		rbufSize:     defaultBufferSize,
		wbufSize:     defaultBufferSize,
		pauseCh:      running,
		pauseChMutex: new(sync.RWMutex),
	}

	for i := range opts {
		if err := opts[i](p); err != nil {
			return nil, ErrNewProxy(err)
		}
	}

	if p.ln == nil {
		return p, nil
	}

	go func() {
		p.log(p.run().Error())
	}()
	return p, nil
}
// ListenerAddr returns the listener's address in string form, or the
// empty string when no listener has been set yet.
func (p *proxy) ListenerAddr() string {
	if ln := p.ln; ln != nil {
		return ln.Addr().String()
	}
	return ""
}
// Start brings up the listener and then serves in a background
// goroutine, returning immediately. It returns an error only when the
// listener cannot be started; the error that eventually ends the
// accept loop is logged rather than returned.
func (p *proxy) Start() error {
	err := p.startListener()
	if err != nil {
		return err
	}

	go func() {
		p.log(p.run().Error())
	}()
	return nil
}
// Run brings up the listener and then serves on the calling
// goroutine: it runs the accept loop and hands each connection off to
// a proxy handler.
//
// Run blocks until the accept loop ends and returns the error that
// ended it, or the listener start-up error if the listener could not
// be started. Use Start() instead to return immediately and have the
// loop run in the background.
func (p *proxy) Run() error {
	err := p.startListener()
	if err == nil {
		err = p.run()
	}
	return err
}
// Close shuts down the listener first and then all connections. Both
// steps are always attempted; any failures are combined into a single
// error via ErrProxyClose. It returns nil when both steps succeed.
func (p *proxy) Close() error {
	steps := []func() error{p.CloseListener, p.CloseConnections}

	var failures []error
	for _, step := range steps {
		if err := step(); err != nil {
			failures = append(failures, err)
		}
	}

	if len(failures) == 0 {
		return nil
	}
	return ErrProxyClose(failures)
}
// CloseListener closes the listener and leaves existing connections
// untouched. It is a no-op returning nil when no listener is set. A
// close failure is returned wrapped by ErrProxyCloseListener.
func (p *proxy) CloseListener() error {
	if p.ln == nil {
		return nil
	}
	err := p.ln.Close()
	if err == nil {
		return nil
	}
	return ErrProxyCloseListener(err)
}
// CloseConnections closes every tracked connection and leaves the
// listener untouched. If any close fails, the failures are reported
// together through ErrProxyCloseConnections; otherwise it returns nil.
func (p *proxy) CloseConnections() error {
	if errs := p.conns.CloseAll(); len(errs) > 0 {
		return ErrProxyCloseConnections(errs)
	}
	return nil
}
// Pause re-initializes the internal pause channel and leaves it
// open.
//
// This causes any and all handlers to stop what they are doing:
// sending and receiving is paused after the most recent copy is
// done, and new connections are blocked after connecting.
//
// Note that any copies that are currently blocked will complete
// before pausing. Consider turning buffers down if you are having
// trouble pausing mid-stream.
//
// Calling Pause while the proxy is already paused is a no-op: the
// existing open channel is kept, so handlers already waiting on it are
// still released when that channel is closed.
func (p *proxy) Pause() {
	p.pauseChMutex.Lock()
	defer p.pauseChMutex.Unlock()

	if p.pauseCh != nil {
		select {
		case <-p.pauseCh:
			// Nothing is ever sent on pauseCh, so a successful receive
			// means it is closed: the proxy is running and can be paused.
		default:
			// Still open: already paused. Replacing the channel here
			// would strand every handler blocked on the current one.
			return
		}
	}
	p.pauseCh = make(chan struct{})
}
// Resume resumes any blocked connections by closing the internal
// pause channel. After this, Pause must be called again to pause
// connections.
//
// Calling Resume when the proxy is not paused, including calling it
// twice in a row, is a no-op.
func (p *proxy) Resume() {
	p.pauseChMutex.Lock()
	defer p.pauseChMutex.Unlock()

	if p.pauseCh == nil {
		return
	}
	select {
	case <-p.pauseCh:
		// Nothing is ever sent on pauseCh, so a successful receive means
		// it is already closed; closing it again would panic.
	default:
		close(p.pauseCh)
	}
}
// run is the accept loop. Every accepted connection is recorded in
// the connection map and served by Handle in its own goroutine; when
// the handler finishes, its error is logged and the connection is
// removed from the map. The loop only ends when Accept fails, and that
// error is returned wrapped by ErrProxyRun.
func (p *proxy) run() error {
	for {
		accepted, acceptErr := p.ln.Accept()
		if acceptErr != nil {
			return ErrProxyRun(acceptErr)
		}

		p.conns.Store(accepted)
		go func() {
			handleErr := p.Handle(accepted)
			p.log(handleErr.Error())
			p.conns.Delete(accepted)
		}()
	}
}
// Handle is the general read-write handler for the connection.
//
// Handle dials the remote host and then copies data in both
// directions between local and the remote, honoring the pause
// channel before every copy. It blocks until either direction fails
// (including a normal end of stream), then closes both connections.
//
// The returned error is never nil: it is either the dial failure
// wrapped by ErrProxyHandleRemoteConnect, or the first copy error
// wrapped by ErrProxyHandleStream.
func (p *proxy) Handle(local net.Conn) error {
	defer local.Close()

	// Connect to remote
	remote, err := net.Dial(p.proto, p.remoteAddr)
	if err != nil {
		return ErrProxyHandleRemoteConnect(err)
	}
	defer remote.Close()

	// One slot per copy goroutine. Handle only receives the first error,
	// so with an unbuffered channel the other goroutine would block on
	// its send forever once the connections are closed underneath it.
	errCh := make(chan error, 2)

	// pump copies src to dst in chunks of at most chunk bytes until a
	// copy fails, waiting on the pause channel before each chunk.
	pump := func(dst io.Writer, src io.Reader, chunk int) {
		for {
			// Snapshot the channel under the read lock, then wait on it
			// outside the lock so Pause and Resume are never blocked.
			p.pauseChMutex.RLock()
			pauseCh := p.pauseCh
			p.pauseChMutex.RUnlock()
			<-pauseCh

			if _, copyErr := io.CopyN(dst, src, int64(chunk)); copyErr != nil {
				errCh <- copyErr
				return
			}
		}
	}

	go pump(remote, local, p.rbufSize)
	go pump(local, remote, p.wbufSize)

	return ErrProxyHandleStream(<-errCh)
}
// startListener opens the listener on the proxy's local address
// using its configured protocol. It fails, wrapped by
// ErrProxyListener, when a listener is already set or when listening
// fails.
func (p *proxy) startListener() error {
	if p.ln != nil {
		return ErrProxyListener(errors.New("listener already started"))
	}

	ln, err := net.Listen(p.proto, p.localAddr)
	if err != nil {
		return ErrProxyListener(err)
	}
	p.ln = ln
	return nil
}
// log writes s as one line to the proxy's logger. It does nothing
// when no logger has been configured.
func (p *proxy) log(s string) {
	if p.logger == nil {
		return
	}
	p.logger.Println(s)
}