diff --git a/http_proxy.go b/http_proxy.go index 448a6a88..52a87c3d 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -17,7 +17,6 @@ import ( "net/http" "net/url" "slices" - "sync" "time" "github.com/saucelabs/forwarder/hostsfile" @@ -28,6 +27,8 @@ import ( "github.com/saucelabs/forwarder/log" "github.com/saucelabs/forwarder/middleware" "github.com/saucelabs/forwarder/pac" + "go.uber.org/multierr" + "golang.org/x/sync/errgroup" ) type ProxyLocalhostMode string @@ -80,6 +81,7 @@ var ErrConnectFallback = martian.ErrConnectFallback type HTTPProxyConfig struct { HTTPServerConfig + ExtraListeners []NamedListenerConfig Name string MITM *MITMConfig MITMDomains Matcher @@ -122,6 +124,11 @@ func (c *HTTPProxyConfig) Validate() error { if err := c.HTTPServerConfig.Validate(); err != nil { return err } + for _, lc := range c.ExtraListeners { + if lc.Name == "" { + return errors.New("extra listener name is required") + } + } if c.Protocol != HTTPScheme && c.Protocol != HTTPSScheme { return fmt.Errorf("unsupported protocol: %s", c.Protocol) } @@ -148,7 +155,7 @@ type HTTPProxy struct { localhost []string tlsConfig *tls.Config - listener net.Listener + listeners []net.Listener } // NewHTTPProxy creates a new HTTP proxy. @@ -171,13 +178,15 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher, } hp.localhost = append(hp.localhost, lh...) - l, err := hp.listen() + ll, err := hp.listen() if err != nil { return nil, err } - hp.listener = l + hp.listeners = ll - hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol) + for _, l := range hp.listeners { + hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol) + } return hp, nil } @@ -500,79 +509,117 @@ func (hp *HTTPProxy) handler() http.Handler { } func (hp *HTTPProxy) Run(ctx context.Context) error { - var srv *http.Server - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - defer wg.Done() - - <-ctx.Done() - if srv != nil { - if err := srv.Shutdown(context.Background()); err != nil { - hp.log.Errorf("failed to shutdown server error=%s", err) - } - } else { - hp.Close() - } - }() - - var srvErr error if hp.config.TestingHTTPHandler { hp.log.Infof("using http handler") - srv = &http.Server{ - Handler: hp.handler(), - IdleTimeout: hp.config.IdleTimeout, - ReadTimeout: hp.config.ReadTimeout, - ReadHeaderTimeout: hp.config.ReadHeaderTimeout, - WriteTimeout: hp.config.WriteTimeout, - } - srvErr = srv.Serve(hp.listener) - } else { - srvErr = hp.proxy.Serve(hp.listener) + return hp.runHTTPHandler(ctx) + } + return hp.run(ctx) +} + +func (hp *HTTPProxy) runHTTPHandler(ctx context.Context) error { + srv := http.Server{ + Handler: hp.handler(), + IdleTimeout: hp.config.IdleTimeout, + ReadTimeout: hp.config.ReadTimeout, + ReadHeaderTimeout: hp.config.ReadHeaderTimeout, + WriteTimeout: hp.config.WriteTimeout, } - if srvErr != nil { - if errors.Is(srvErr, net.ErrClosed) { - srvErr = nil + + var g errgroup.Group + g.Go(func() error { + <-ctx.Done() + if err := srv.Shutdown(context.Background()); err != nil { + hp.log.Errorf("failed to shutdown server error=%s", err) } - return srvErr + return ctx.Err() + }) + for i := range hp.listeners { + l := hp.listeners[i] + g.Go(func() error { + err := srv.Serve(l) + if errors.Is(err, http.ErrServerClosed) { + err = nil + } + return err + }) } + return g.Wait() +} - wg.Wait() - return nil +func (hp *HTTPProxy) run(ctx context.Context) error { + var g errgroup.Group + g.Go(func() error { + <-ctx.Done() + hp.Close() + return ctx.Err() + }) + for i := range hp.listeners { + l := hp.listeners[i] + g.Go(func() error { + err := hp.proxy.Serve(l) + if errors.Is(err, net.ErrClosed) { + err = nil + } + return err + }) + } + return g.Wait() } -func (hp *HTTPProxy) listen() (net.Listener, error) { +func (hp *HTTPProxy) listen() ([]net.Listener, error) { switch hp.config.Protocol { case HTTPScheme, HTTPSScheme, HTTP2Scheme: default: return nil, fmt.Errorf("invalid protocol %q", hp.config.Protocol) } - l := Listener{ - ListenerConfig: hp.config.ListenerConfig, - TLSConfig: hp.tlsConfig, - PromConfig: PromConfig{ - PromNamespace: hp.config.PromNamespace, - PromRegistry: hp.config.PromRegistry, - }, - } - - if err := l.Listen(); err != nil { - return nil, err + if len(hp.config.ExtraListeners) == 0 { + l := &Listener{ + ListenerConfig: hp.config.ListenerConfig, + TLSConfig: hp.tlsConfig, + PromConfig: PromConfig{ + PromNamespace: hp.config.PromNamespace, + PromRegistry: hp.config.PromRegistry, + }, + } + if err := l.Listen(); err != nil { + return nil, err + } + return []net.Listener{l}, nil } - return &l, nil + return MultiListener{ + ListenerConfigs: append([]NamedListenerConfig{{ListenerConfig: hp.config.ListenerConfig}}, hp.config.ExtraListeners...), + TLSConfig: func(lc NamedListenerConfig) *tls.Config { + return hp.tlsConfig + }, + PromConfig: hp.config.PromConfig, + }.Listen() } // Addr returns the address the server is listening on. -func (hp *HTTPProxy) Addr() string { - return hp.listener.Addr().String() +func (hp *HTTPProxy) Addr() (addrs []string, ok bool) { + addrs = make([]string, len(hp.listeners)) + ok = true + for i, l := range hp.listeners { + addrs[i] = l.Addr().String() + if addrs[i] == "" { + ok = false + } + } + return } func (hp *HTTPProxy) Close() error { - err := hp.listener.Close() + // Close listeners first to prevent new connections. + var err error + for _, l := range hp.listeners { + if e := l.Close(); e != nil { + err = multierr.Append(err, e) + } + } + + // Close the proxy to stop serving requests. hp.proxy.Close() if tr, ok := hp.transport.(*http.Transport); ok { diff --git a/net.go b/net.go index 3b6365de..04fbcc17 100644 --- a/net.go +++ b/net.go @@ -191,6 +191,47 @@ func DefaultListenerConfig(addr string) *ListenerConfig { } } +type NamedListenerConfig struct { + Name string + ListenerConfig +} + +// MultiListener is a builder for multiple listeners sharing the same prometheus configuration. +// The listener name is added as a label to the metrics. +type MultiListener struct { + ListenerConfigs []NamedListenerConfig + TLSConfig func(NamedListenerConfig) *tls.Config + PromConfig +} + +func (ml MultiListener) Listen() (_ []net.Listener, ferr error) { + listeners := make([]net.Listener, 0, len(ml.ListenerConfigs)) + defer func() { + if ferr != nil { + for _, l := range listeners { + l.Close() + } + } + }() + + mf := newListenerMetricsWithNameFunc(ml.PromRegistry, ml.PromNamespace) + + for _, lc := range ml.ListenerConfigs { + l := new(Listener) + l.ListenerConfig = lc.ListenerConfig + if ml.TLSConfig != nil { + l.TLSConfig = ml.TLSConfig(lc) + } + l.metrics = mf(lc.Name) + if err := l.Listen(); err != nil { + return nil, err + } + listeners = append(listeners, l) + } + + return listeners, nil +} + type Listener struct { ListenerConfig TLSConfig *tls.Config @@ -222,7 +263,10 @@ func (l *Listener) Listen() error { } l.listener = ll - l.metrics = newListenerMetrics(l.PromRegistry, l.PromNamespace) + + if l.metrics == nil { + l.metrics = newListenerMetrics(l.PromRegistry, l.PromNamespace) + } return nil } diff --git a/net_metrics.go b/net_metrics.go index 574fab4c..a01c0993 100644 --- a/net_metrics.go +++ b/net_metrics.go @@ -123,3 +123,34 @@ func (m *listenerMetrics) error() { func (m *listenerMetrics) close() { m.closed.Inc() } + +func newListenerMetricsWithNameFunc(r prometheus.Registerer, namespace string) func(name string) *listenerMetrics { + if r == nil { + r = prometheus.NewRegistry() // This registry will be discarded. + } + f := promauto.With(r) + + accepted := f.NewCounterVec(prometheus.CounterOpts{ + Name: "listener_accepted_total", + Namespace: namespace, + Help: "Number of accepted connections", + }, []string{"name"}) + errors := f.NewCounterVec(prometheus.CounterOpts{ + Name: "listener_errors_total", + Namespace: namespace, + Help: "Number of listener errors when accepting connections", + }, []string{"name"}) + closed := f.NewCounterVec(prometheus.CounterOpts{ + Name: "listener_closed_total", + Namespace: namespace, + Help: "Number of closed connections", + }, []string{"name"}) + + return func(name string) *listenerMetrics { + return &listenerMetrics{ + accepted: accepted.WithLabelValues(name), + errors: errors.WithLabelValues(name), + closed: closed.WithLabelValues(name), + } + } +} diff --git a/net_test.go b/net_test.go index c2519bb4..29eef53d 100644 --- a/net_test.go +++ b/net_test.go @@ -394,3 +394,70 @@ func selfSingedCert() *tls.Config { Certificates: []tls.Certificate{cert}, } } + +func (ml *MultiListener) listenAndWait(t *testing.T) []net.Listener { + t.Helper() + + listeners, err := ml.Listen() + if err != nil { + t.Fatal(err) + } + for _, l := range listeners { + for { + if l.Addr() != nil { + break + } + } + } + return listeners +} + +func TestMultiListenerMetrics(t *testing.T) { + r := prometheus.NewRegistry() + ml := MultiListener{ + ListenerConfigs: []NamedListenerConfig{ + { + Name: "a", + ListenerConfig: ListenerConfig{ + Address: "localhost:0", + }, + }, + { + Name: "b", + ListenerConfig: ListenerConfig{ + Address: "localhost:0", + }, + }, + }, + PromConfig: PromConfig{ + PromNamespace: "test", + PromRegistry: r, + }, + } + listeners := ml.listenAndWait(t) + defer func() { + for _, l := range listeners { + l.Close() + } + }() + + for _, l := range listeners { + go l.(*Listener).acceptAndCopy() //nolint:forcetypeassert // trust the test + } + + for _, l := range listeners { + for range 10 { + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("net.Dial(): got %v, want no error", err) + } + fmt.Fprintf(conn, "Hello, World!\n") + if _, err := conn.Read(make([]byte, 1)); err != nil { + t.Fatal(err) + } + conn.Close() + } + } + + golden.DiffPrometheusMetrics(t, r) +} diff --git a/testdata/TestMultiListenerMetrics.golden.txt b/testdata/TestMultiListenerMetrics.golden.txt new file mode 100644 index 00000000..428e55b2 --- /dev/null +++ b/testdata/TestMultiListenerMetrics.golden.txt @@ -0,0 +1,12 @@ +# HELP test_listener_accepted_total Number of accepted connections +# TYPE test_listener_accepted_total counter +test_listener_accepted_total{name="a"} 10 +test_listener_accepted_total{name="b"} 10 +# HELP test_listener_closed_total Number of closed connections +# TYPE test_listener_closed_total counter +test_listener_closed_total{name="a"} 10 +test_listener_closed_total{name="b"} 10 +# HELP test_listener_errors_total Number of listener errors when accepting connections +# TYPE test_listener_errors_total counter +test_listener_errors_total{name="a"} 0 +test_listener_errors_total{name="b"} 0