From d5e602139b51e951a0212b88488f9fc1326a8739 Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Mon, 14 Oct 2024 12:29:26 +0200 Subject: [PATCH 1/2] net: remove TLS handshake timeout from listener Listener is not the right place to do the handshake. It may block main accept loop. --- http_proxy.go | 1 - net.go | 58 +++++-------------- net_metrics.go | 16 +---- net_test.go | 28 --------- .../TestListenerMetricsAccepted.golden.txt | 3 - ...tListenerMetricsAcceptedWithTLS.golden.txt | 3 - testdata/TestListenerMetricsClosed.golden.txt | 3 - testdata/TestListenerMetricsErrors.golden.txt | 3 - ...TestListenerTLSHandshakeTimeout.golden.txt | 12 ---- 9 files changed, 17 insertions(+), 110 deletions(-) delete mode 100644 testdata/TestListenerTLSHandshakeTimeout.golden.txt diff --git a/http_proxy.go b/http_proxy.go index 3ae445bf..7564844a 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -557,7 +557,6 @@ func (hp *HTTPProxy) listen() (net.Listener, error) { Log: hp.log, ProxyProtocolConfig: hp.config.ProxyProtocolConfig, TLSConfig: hp.tlsConfig, - TLSHandshakeTimeout: hp.config.TLSServerConfig.HandshakeTimeout, ReadLimit: int64(hp.config.ReadLimit), WriteLimit: int64(hp.config.WriteLimit), PromConfig: PromConfig{ diff --git a/net.go b/net.go index 5b60ef24..a31d3b13 100644 --- a/net.go +++ b/net.go @@ -155,7 +155,6 @@ type Listener struct { Address string Log log.Logger TLSConfig *tls.Config - TLSHandshakeTimeout time.Duration ProxyProtocolConfig *ProxyProtocolConfig ReadLimit int64 WriteLimit int64 @@ -192,55 +191,26 @@ func (l *Listener) Listen() error { return nil } +// Accept returns tls.Conn if TLSConfig is set, as martian expects it to be on top. +// Otherwise, it returns forwarder.TrackedConn. func (l *Listener) Accept() (net.Conn, error) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.metrics.error() - return nil, err - } - - if l.TLSConfig == nil { - l.metrics.accept() - return &TrackedConn{ - Conn: conn, - OnClose: l.metrics.close, - }, nil - } - - tr := &TrackedConn{ - Conn: conn, - } - tconn, err := l.withTLS(tr) - if err != nil { - l.Log.Errorf("TLS handshake failed: %v", err) - if cerr := tconn.Close(); cerr != nil { - l.Log.Errorf("error while closing TLS connection after failed handshake: %v", cerr) - } - l.metrics.tlsError() - - continue - } - - l.metrics.accept() - tr.OnClose = l.metrics.close - return tconn, nil + conn, err := l.listener.Accept() + if err != nil { + l.metrics.error() + return nil, err } -} -func (l *Listener) withTLS(conn net.Conn) (*tls.Conn, error) { - tconn := tls.Server(conn, l.TLSConfig) + l.metrics.accept() + conn = &TrackedConn{ + Conn: conn, + OnClose: l.metrics.close, + } - var err error - if l.TLSHandshakeTimeout <= 0 { - err = tconn.Handshake() - } else { - ctx, cancel := context.WithTimeout(context.Background(), l.TLSHandshakeTimeout) - err = tconn.HandshakeContext(ctx) - cancel() + if l.TLSConfig != nil { + conn = tls.Server(conn, l.TLSConfig) } - return tconn, err + return conn, nil } func (l *Listener) Addr() net.Addr { diff --git a/net_metrics.go b/net_metrics.go index d1602236..574fab4c 100644 --- a/net_metrics.go +++ b/net_metrics.go @@ -82,10 +82,9 @@ func addr2Host(addr string) string { } type listenerMetrics struct { - accepted prometheus.Counter - errors prometheus.Counter - tlsErrors prometheus.Counter - closed prometheus.Counter + accepted prometheus.Counter + errors prometheus.Counter + closed prometheus.Counter } func newListenerMetrics(r prometheus.Registerer, namespace string) *listenerMetrics { @@ -105,11 +104,6 @@ func newListenerMetrics(r prometheus.Registerer, namespace string) *listenerMetr Namespace: namespace, Help: "Number of listener errors when accepting connections", }), - tlsErrors: f.NewCounter(prometheus.CounterOpts{ - Name: "listener_tls_errors_total", - Namespace: namespace, - Help: "Number of TLS handshake errors", - }), closed: f.NewCounter(prometheus.CounterOpts{ Name: "listener_closed_total", Namespace: namespace, @@ -126,10 +120,6 @@ func (m *listenerMetrics) error() { m.errors.Inc() } -func (m *listenerMetrics) tlsError() { - m.tlsErrors.Inc() -} - func (m *listenerMetrics) close() { m.closed.Inc() } diff --git a/net_test.go b/net_test.go index f152c28f..17e5b480 100644 --- a/net_test.go +++ b/net_test.go @@ -344,34 +344,6 @@ func TestListenerMetricsErrors(t *testing.T) { golden.DiffPrometheusMetrics(t, r) } -func TestListenerTLSHandshakeTimeout(t *testing.T) { - r := prometheus.NewRegistry() - l := Listener{ - Address: "localhost:0", - Log: log.NopLogger, - TLSConfig: selfSingedCert(), - TLSHandshakeTimeout: 100 * time.Millisecond, - PromConfig: PromConfig{ - PromNamespace: "test", - PromRegistry: r, - }, - } - defer l.Close() - - l.listenAndWait(t) - go l.acceptAndCopy() - - conn, err := net.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("net.Dial(): got %v, want no error", err) - } - defer conn.Close() - - time.Sleep(l.TLSHandshakeTimeout * 2) - - golden.DiffPrometheusMetrics(t, r) -} - func selfSingedCert() *tls.Config { ssc := certutil.ECDSASelfSignedCert() ssc.Hosts = append(ssc.Hosts, "localhost") diff --git a/testdata/TestListenerMetricsAccepted.golden.txt b/testdata/TestListenerMetricsAccepted.golden.txt index fb7a12b0..65743b11 100644 --- a/testdata/TestListenerMetricsAccepted.golden.txt +++ b/testdata/TestListenerMetricsAccepted.golden.txt @@ -7,6 +7,3 @@ test_listener_closed_total 10 # HELP test_listener_errors_total Number of listener errors when accepting connections # TYPE test_listener_errors_total counter test_listener_errors_total 0 -# HELP test_listener_tls_errors_total Number of TLS handshake errors -# TYPE test_listener_tls_errors_total counter -test_listener_tls_errors_total 0 diff --git a/testdata/TestListenerMetricsAcceptedWithTLS.golden.txt b/testdata/TestListenerMetricsAcceptedWithTLS.golden.txt index fb7a12b0..65743b11 100644 --- a/testdata/TestListenerMetricsAcceptedWithTLS.golden.txt +++ b/testdata/TestListenerMetricsAcceptedWithTLS.golden.txt @@ -7,6 +7,3 @@ test_listener_closed_total 10 # HELP test_listener_errors_total Number of listener errors when accepting connections # TYPE test_listener_errors_total counter test_listener_errors_total 0 -# HELP test_listener_tls_errors_total Number of TLS handshake errors -# TYPE test_listener_tls_errors_total counter -test_listener_tls_errors_total 0 diff --git a/testdata/TestListenerMetricsClosed.golden.txt b/testdata/TestListenerMetricsClosed.golden.txt index a2cb6d92..0f333e93 100644 --- a/testdata/TestListenerMetricsClosed.golden.txt +++ b/testdata/TestListenerMetricsClosed.golden.txt @@ -7,6 +7,3 @@ test_listener_closed_total 1 # HELP test_listener_errors_total Number of listener errors when accepting connections # TYPE test_listener_errors_total counter test_listener_errors_total 0 -# HELP test_listener_tls_errors_total Number of TLS handshake errors -# TYPE test_listener_tls_errors_total counter -test_listener_tls_errors_total 0 diff --git a/testdata/TestListenerMetricsErrors.golden.txt b/testdata/TestListenerMetricsErrors.golden.txt index e6be8bf7..e84ad5eb 100644 --- a/testdata/TestListenerMetricsErrors.golden.txt +++ b/testdata/TestListenerMetricsErrors.golden.txt @@ -7,6 +7,3 @@ test_listener_closed_total 0 # HELP test_listener_errors_total Number of listener errors when accepting connections # TYPE test_listener_errors_total counter test_listener_errors_total 1 -# HELP test_listener_tls_errors_total Number of TLS handshake errors -# TYPE test_listener_tls_errors_total counter -test_listener_tls_errors_total 0 diff --git a/testdata/TestListenerTLSHandshakeTimeout.golden.txt b/testdata/TestListenerTLSHandshakeTimeout.golden.txt deleted file mode 100644 index 355610b8..00000000 --- a/testdata/TestListenerTLSHandshakeTimeout.golden.txt +++ /dev/null @@ -1,12 +0,0 @@ -# HELP test_listener_accepted_total Number of accepted connections -# TYPE test_listener_accepted_total counter -test_listener_accepted_total 0 -# HELP test_listener_closed_total Number of closed connections -# TYPE test_listener_closed_total counter -test_listener_closed_total 0 -# HELP test_listener_errors_total Number of listener errors when accepting connections -# TYPE test_listener_errors_total counter -test_listener_errors_total 0 -# HELP test_listener_tls_errors_total Number of TLS handshake errors -# TYPE test_listener_tls_errors_total counter -test_listener_tls_errors_total 1 From 06a73d539889c068db8242f6221af20be57e9eff Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Mon, 14 Oct 2024 13:05:52 +0200 Subject: [PATCH 2/2] martian: add support for TLSHandshakeTimeout Martian Proxy will explicitly do the handshake if accepted connection is tls.Conn. --- http_proxy.go | 1 + internal/martian/proxy.go | 10 ++++++++++ internal/martian/proxy_conn.go | 25 ++++++++++++++++++++----- internal/martian/proxy_test.go | 31 +++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 5 deletions(-) diff --git a/http_proxy.go b/http_proxy.go index 7564844a..a8641859 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -250,6 +250,7 @@ func (hp *HTTPProxy) configureProxy() error { hp.proxy.WithoutWarning = true hp.proxy.ErrorResponse = hp.errorResponse hp.proxy.IdleTimeout = hp.config.IdleTimeout + hp.proxy.TLSHandshakeTimeout = hp.config.TLSServerConfig.HandshakeTimeout hp.proxy.ReadTimeout = hp.config.ReadTimeout hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout hp.proxy.WriteTimeout = hp.config.WriteTimeout diff --git a/internal/martian/proxy.go b/internal/martian/proxy.go index 86a4c9a0..5682380a 100644 --- a/internal/martian/proxy.go +++ b/internal/martian/proxy.go @@ -83,6 +83,11 @@ type Proxy struct { // If both are zero, there is no timeout. IdleTimeout time.Duration + // TLSHandshakeTimeout is the maximum amount of time to wait for a TLS handshake. + // The proxy will try to cast accepted connections to tls.Conn and perform a handshake. + // If TLSHandshakeTimeout is zero, no timeout is set. + TLSHandshakeTimeout time.Duration + // ReadTimeout is the maximum duration for reading the entire // request, including the body. A zero or negative value means // there will be no timeout. @@ -257,6 +262,11 @@ func (p *Proxy) handleLoop(conn net.Conn) { pc := newProxyConn(p, conn) + if err := pc.maybeHandshakeTLS(); err != nil { + log.Errorf(context.TODO(), "failed to do TLS handshake: %v", err) + return + } + const maxConsecutiveErrors = 5 errorsN := 0 for { diff --git a/internal/martian/proxy_conn.go b/internal/martian/proxy_conn.go index 232909df..2a603152 100644 --- a/internal/martian/proxy_conn.go +++ b/internal/martian/proxy_conn.go @@ -41,18 +41,33 @@ type proxyConn struct { } func newProxyConn(p *Proxy, conn net.Conn) *proxyConn { - v := &proxyConn{ + return &proxyConn{ Proxy: p, brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), conn: conn, } +} + +func (p *proxyConn) maybeHandshakeTLS() error { + tconn, ok := p.conn.(*tls.Conn) + if !ok { + return nil + } - if tconn, ok := conn.(*tls.Conn); ok { - v.secure = true - v.cs = tconn.ConnectionState() + ctx := context.Background() + if p.TLSHandshakeTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), p.TLSHandshakeTimeout) + defer cancel() + } + if err := tconn.HandshakeContext(ctx); err != nil { + return err } - return v + p.secure = true + p.cs = tconn.ConnectionState() + + return nil } func (p *proxyConn) readRequest() (*http.Request, error) { diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index d51d5c86..ba3bc7da 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -1838,6 +1838,37 @@ func TestIdleTimeout(t *testing.T) { } } +func TestTLSHandshakeTimeout(t *testing.T) { + t.Parallel() + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(): got %v, want no error", err) + } + _, mc := certs(t) + l = tls.NewListener(l, mc.TLS(context.Background())) + + h := testHelper{ + Listener: l, + Proxy: func(p *Proxy) { + p.TLSHandshakeTimeout = 100 * time.Millisecond + }, + } + + c, cancel := h.proxyClient(t) + defer cancel() + + conn, err := net.Dial("tcp", c.Addr) + if err != nil { + t.Fatalf("net.Dial(): got %v, want no error", err) + } + + time.Sleep(200 * time.Millisecond) + if _, err := conn.Read(make([]byte, 1)); !errors.Is(err, io.EOF) { + t.Fatalf("conn.Read(): got %v, want io.EOF", err) + } +} + func TestReadHeaderTimeout(t *testing.T) { t.Parallel()