Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http proxy: allow multiple listeners #949

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 103 additions & 56 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"net/http"
"net/url"
"slices"
"sync"
"time"

"github.com/saucelabs/forwarder/hostsfile"
Expand All @@ -28,6 +27,8 @@ import (
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/middleware"
"github.com/saucelabs/forwarder/pac"
"go.uber.org/multierr"
"golang.org/x/sync/errgroup"
)

type ProxyLocalhostMode string
Expand Down Expand Up @@ -80,6 +81,7 @@ var ErrConnectFallback = martian.ErrConnectFallback

type HTTPProxyConfig struct {
HTTPServerConfig
ExtraListeners []NamedListenerConfig
Name string
MITM *MITMConfig
MITMDomains Matcher
Expand Down Expand Up @@ -122,6 +124,11 @@ func (c *HTTPProxyConfig) Validate() error {
if err := c.HTTPServerConfig.Validate(); err != nil {
return err
}
for _, lc := range c.ExtraListeners {
if lc.Name == "" {
return errors.New("extra listener name is required")
}
}
if c.Protocol != HTTPScheme && c.Protocol != HTTPSScheme {
return fmt.Errorf("unsupported protocol: %s", c.Protocol)
}
Expand All @@ -148,7 +155,7 @@ type HTTPProxy struct {
localhost []string

tlsConfig *tls.Config
listener net.Listener
listeners []net.Listener
}

// NewHTTPProxy creates a new HTTP proxy.
Expand All @@ -171,13 +178,15 @@ func NewHTTPProxy(cfg *HTTPProxyConfig, pr PACResolver, cm *CredentialsMatcher,
}
hp.localhost = append(hp.localhost, lh...)

l, err := hp.listen()
ll, err := hp.listen()
if err != nil {
return nil, err
}
hp.listener = l
hp.listeners = ll

hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol)
for _, l := range hp.listeners {
hp.log.Infof("PROXY server listen address=%s protocol=%s", l.Addr(), hp.config.Protocol)
}

return hp, nil
}
Expand Down Expand Up @@ -500,79 +509,117 @@ func (hp *HTTPProxy) handler() http.Handler {
}

func (hp *HTTPProxy) Run(ctx context.Context) error {
var srv *http.Server

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()

<-ctx.Done()
if srv != nil {
if err := srv.Shutdown(context.Background()); err != nil {
hp.log.Errorf("failed to shutdown server error=%s", err)
}
} else {
hp.Close()
}
}()

var srvErr error
if hp.config.TestingHTTPHandler {
hp.log.Infof("using http handler")
srv = &http.Server{
Handler: hp.handler(),
IdleTimeout: hp.config.IdleTimeout,
ReadTimeout: hp.config.ReadTimeout,
ReadHeaderTimeout: hp.config.ReadHeaderTimeout,
WriteTimeout: hp.config.WriteTimeout,
}
srvErr = srv.Serve(hp.listener)
} else {
srvErr = hp.proxy.Serve(hp.listener)
return hp.runHTTPHandler(ctx)
}
return hp.run(ctx)
}

func (hp *HTTPProxy) runHTTPHandler(ctx context.Context) error {
srv := http.Server{
Handler: hp.handler(),
IdleTimeout: hp.config.IdleTimeout,
ReadTimeout: hp.config.ReadTimeout,
ReadHeaderTimeout: hp.config.ReadHeaderTimeout,
WriteTimeout: hp.config.WriteTimeout,
}
if srvErr != nil {
if errors.Is(srvErr, net.ErrClosed) {
srvErr = nil

var g errgroup.Group
g.Go(func() error {
<-ctx.Done()
if err := srv.Shutdown(context.Background()); err != nil {
hp.log.Errorf("failed to shutdown server error=%s", err)
}
return srvErr
return ctx.Err()
})
for i := range hp.listeners {
l := hp.listeners[i]
g.Go(func() error {
err := srv.Serve(l)
if errors.Is(err, http.ErrServerClosed) {
err = nil
}
return err
})
}
return g.Wait()
}

wg.Wait()
return nil
func (hp *HTTPProxy) run(ctx context.Context) error {
var g errgroup.Group
g.Go(func() error {
<-ctx.Done()
hp.Close()
return ctx.Err()
})
for i := range hp.listeners {
l := hp.listeners[i]
g.Go(func() error {
err := hp.proxy.Serve(l)
if errors.Is(err, net.ErrClosed) {
err = nil
}
return err
})
}
return g.Wait()
}

func (hp *HTTPProxy) listen() (net.Listener, error) {
func (hp *HTTPProxy) listen() ([]net.Listener, error) {
switch hp.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
default:
return nil, fmt.Errorf("invalid protocol %q", hp.config.Protocol)
}

l := Listener{
ListenerConfig: hp.config.ListenerConfig,
TLSConfig: hp.tlsConfig,
PromConfig: PromConfig{
PromNamespace: hp.config.PromNamespace,
PromRegistry: hp.config.PromRegistry,
},
}

if err := l.Listen(); err != nil {
return nil, err
if len(hp.config.ExtraListeners) == 0 {
l := &Listener{
ListenerConfig: hp.config.ListenerConfig,
TLSConfig: hp.tlsConfig,
PromConfig: PromConfig{
PromNamespace: hp.config.PromNamespace,
PromRegistry: hp.config.PromRegistry,
},
}
if err := l.Listen(); err != nil {
return nil, err
}
return []net.Listener{l}, nil
}

return &l, nil
return MultiListener{
ListenerConfigs: append([]NamedListenerConfig{{ListenerConfig: hp.config.ListenerConfig}}, hp.config.ExtraListeners...),
TLSConfig: func(lc NamedListenerConfig) *tls.Config {
return hp.tlsConfig
},
PromConfig: hp.config.PromConfig,
}.Listen()
}

// Addr returns the address the server is listening on.
func (hp *HTTPProxy) Addr() string {
return hp.listener.Addr().String()
func (hp *HTTPProxy) Addr() (addrs []string, ok bool) {
addrs = make([]string, len(hp.listeners))
ok = true
for i, l := range hp.listeners {
addrs[i] = l.Addr().String()
if addrs[i] == "" {
ok = false
}
}
return
}

func (hp *HTTPProxy) Close() error {
err := hp.listener.Close()
// Close listeners first to prevent new connections.
var err error
for _, l := range hp.listeners {
if e := l.Close(); e != nil {
err = multierr.Append(err, e)
}
}

// Close the proxy to stop serving requests.
hp.proxy.Close()

if tr, ok := hp.transport.(*http.Transport); ok {
Expand Down
46 changes: 45 additions & 1 deletion net.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,47 @@ func DefaultListenerConfig(addr string) *ListenerConfig {
}
}

type NamedListenerConfig struct {
Name string
ListenerConfig
}

// MultiListener is a builder for multiple listeners sharing the same prometheus configuration.
// The listener name is added as a label to the metrics.
type MultiListener struct {
ListenerConfigs []NamedListenerConfig
TLSConfig func(NamedListenerConfig) *tls.Config
PromConfig
}

func (ml MultiListener) Listen() (_ []net.Listener, ferr error) {
listeners := make([]net.Listener, 0, len(ml.ListenerConfigs))
defer func() {
if ferr != nil {
for _, l := range listeners {
l.Close()
}
}
}()

mf := newListenerMetricsWithNameFunc(ml.PromRegistry, ml.PromNamespace)

for _, lc := range ml.ListenerConfigs {
l := new(Listener)
l.ListenerConfig = lc.ListenerConfig
if ml.TLSConfig != nil {
l.TLSConfig = ml.TLSConfig(lc)
}
l.metrics = mf(lc.Name)
if err := l.Listen(); err != nil {
return nil, err
}
listeners = append(listeners, l)
}

return listeners, nil
}

type Listener struct {
ListenerConfig
TLSConfig *tls.Config
Expand Down Expand Up @@ -222,7 +263,10 @@ func (l *Listener) Listen() error {
}

l.listener = ll
l.metrics = newListenerMetrics(l.PromRegistry, l.PromNamespace)

if l.metrics == nil {
l.metrics = newListenerMetrics(l.PromRegistry, l.PromNamespace)
}

return nil
}
Expand Down
31 changes: 31 additions & 0 deletions net_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,34 @@ func (m *listenerMetrics) error() {
func (m *listenerMetrics) close() {
m.closed.Inc()
}

func newListenerMetricsWithNameFunc(r prometheus.Registerer, namespace string) func(name string) *listenerMetrics {
if r == nil {
r = prometheus.NewRegistry() // This registry will be discarded.
}
f := promauto.With(r)

accepted := f.NewCounterVec(prometheus.CounterOpts{
Name: "listener_accepted_total",
Namespace: namespace,
Help: "Number of accepted connections",
}, []string{"name"})
errors := f.NewCounterVec(prometheus.CounterOpts{
Name: "listener_errors_total",
Namespace: namespace,
Help: "Number of listener errors when accepting connections",
}, []string{"name"})
closed := f.NewCounterVec(prometheus.CounterOpts{
Name: "listener_closed_total",
Namespace: namespace,
Help: "Number of closed connections",
}, []string{"name"})

return func(name string) *listenerMetrics {
return &listenerMetrics{
accepted: accepted.WithLabelValues(name),
errors: errors.WithLabelValues(name),
closed: closed.WithLabelValues(name),
}
}
}
67 changes: 67 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,70 @@ func selfSingedCert() *tls.Config {
Certificates: []tls.Certificate{cert},
}
}

func (ml *MultiListener) listenAndWait(t *testing.T) []net.Listener {
t.Helper()

listeners, err := ml.Listen()
if err != nil {
t.Fatal(err)
}
for _, l := range listeners {
for {
if l.Addr() != nil {
break
}
}
}
return listeners
}

func TestMultiListenerMetrics(t *testing.T) {
r := prometheus.NewRegistry()
ml := MultiListener{
ListenerConfigs: []NamedListenerConfig{
{
Name: "a",
ListenerConfig: ListenerConfig{
Address: "localhost:0",
},
},
{
Name: "b",
ListenerConfig: ListenerConfig{
Address: "localhost:0",
},
},
},
PromConfig: PromConfig{
PromNamespace: "test",
PromRegistry: r,
},
}
listeners := ml.listenAndWait(t)
defer func() {
for _, l := range listeners {
l.Close()
}
}()

for _, l := range listeners {
go l.(*Listener).acceptAndCopy() //nolint:forcetypeassert // trust the test
}

for _, l := range listeners {
for range 10 {
conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}
fmt.Fprintf(conn, "Hello, World!\n")
if _, err := conn.Read(make([]byte, 1)); err != nil {
t.Fatal(err)
}
conn.Close()
}
}

golden.DiffPrometheusMetrics(t, r)
}
Loading