Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

martian: move TLS handshake from forwarder.Listener #941

Merged
merged 2 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func (hp *HTTPProxy) configureProxy() error {
hp.proxy.WithoutWarning = true
hp.proxy.ErrorResponse = hp.errorResponse
hp.proxy.IdleTimeout = hp.config.IdleTimeout
hp.proxy.TLSHandshakeTimeout = hp.config.TLSServerConfig.HandshakeTimeout
hp.proxy.ReadTimeout = hp.config.ReadTimeout
hp.proxy.ReadHeaderTimeout = hp.config.ReadHeaderTimeout
hp.proxy.WriteTimeout = hp.config.WriteTimeout
Expand Down Expand Up @@ -557,7 +558,6 @@ func (hp *HTTPProxy) listen() (net.Listener, error) {
Log: hp.log,
ProxyProtocolConfig: hp.config.ProxyProtocolConfig,
TLSConfig: hp.tlsConfig,
TLSHandshakeTimeout: hp.config.TLSServerConfig.HandshakeTimeout,
ReadLimit: int64(hp.config.ReadLimit),
WriteLimit: int64(hp.config.WriteLimit),
PromConfig: PromConfig{
Expand Down
10 changes: 10 additions & 0 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ type Proxy struct {
// If both are zero, there is no timeout.
IdleTimeout time.Duration

// TLSHandshakeTimeout is the maximum amount of time to wait for a TLS handshake.
// The proxy will try to cast accepted connections to tls.Conn and perform a handshake.
// If TLSHandshakeTimeout is zero, no timeout is set.
TLSHandshakeTimeout time.Duration

// ReadTimeout is the maximum duration for reading the entire
// request, including the body. A zero or negative value means
// there will be no timeout.
Expand Down Expand Up @@ -257,6 +262,11 @@ func (p *Proxy) handleLoop(conn net.Conn) {

pc := newProxyConn(p, conn)

if err := pc.maybeHandshakeTLS(); err != nil {
log.Errorf(context.TODO(), "failed to do TLS handshake: %v", err)
return
}

const maxConsecutiveErrors = 5
errorsN := 0
for {
Expand Down
25 changes: 20 additions & 5 deletions internal/martian/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,33 @@ type proxyConn struct {
}

func newProxyConn(p *Proxy, conn net.Conn) *proxyConn {
v := &proxyConn{
return &proxyConn{
Proxy: p,
brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
conn: conn,
}
}

func (p *proxyConn) maybeHandshakeTLS() error {
tconn, ok := p.conn.(*tls.Conn)
if !ok {
return nil
}

if tconn, ok := conn.(*tls.Conn); ok {
v.secure = true
v.cs = tconn.ConnectionState()
ctx := context.Background()
if p.TLSHandshakeTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), p.TLSHandshakeTimeout)
defer cancel()
}
if err := tconn.HandshakeContext(ctx); err != nil {
return err
}

return v
p.secure = true
p.cs = tconn.ConnectionState()

return nil
}

func (p *proxyConn) readRequest() (*http.Request, error) {
Expand Down
31 changes: 31 additions & 0 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,37 @@ func TestIdleTimeout(t *testing.T) {
}
}

func TestTLSHandshakeTimeout(t *testing.T) {
t.Parallel()

l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("net.Listen(): got %v, want no error", err)
}
_, mc := certs(t)
l = tls.NewListener(l, mc.TLS(context.Background()))

h := testHelper{
Listener: l,
Proxy: func(p *Proxy) {
p.TLSHandshakeTimeout = 100 * time.Millisecond
},
}

c, cancel := h.proxyClient(t)
defer cancel()

conn, err := net.Dial("tcp", c.Addr)
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}

time.Sleep(200 * time.Millisecond)
if _, err := conn.Read(make([]byte, 1)); !errors.Is(err, io.EOF) {
t.Fatalf("conn.Read(): got %v, want io.EOF", err)
}
}

func TestReadHeaderTimeout(t *testing.T) {
t.Parallel()

Expand Down
58 changes: 14 additions & 44 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ type Listener struct {
Address string
Log log.Logger
TLSConfig *tls.Config
TLSHandshakeTimeout time.Duration
ProxyProtocolConfig *ProxyProtocolConfig
ReadLimit int64
WriteLimit int64
Expand Down Expand Up @@ -192,55 +191,26 @@ func (l *Listener) Listen() error {
return nil
}

// Accept returns tls.Conn if TLSConfig is set, as martian expects it to be on top.
// Otherwise, it returns forwarder.TrackedConn.
func (l *Listener) Accept() (net.Conn, error) {
for {
conn, err := l.listener.Accept()
if err != nil {
l.metrics.error()
return nil, err
}

if l.TLSConfig == nil {
l.metrics.accept()
return &TrackedConn{
Conn: conn,
OnClose: l.metrics.close,
}, nil
}

tr := &TrackedConn{
Conn: conn,
}
tconn, err := l.withTLS(tr)
if err != nil {
l.Log.Errorf("TLS handshake failed: %v", err)
if cerr := tconn.Close(); cerr != nil {
l.Log.Errorf("error while closing TLS connection after failed handshake: %v", cerr)
}
l.metrics.tlsError()

continue
}

l.metrics.accept()
tr.OnClose = l.metrics.close
return tconn, nil
conn, err := l.listener.Accept()
if err != nil {
l.metrics.error()
return nil, err
}
}

func (l *Listener) withTLS(conn net.Conn) (*tls.Conn, error) {
tconn := tls.Server(conn, l.TLSConfig)
l.metrics.accept()
conn = &TrackedConn{
Conn: conn,
OnClose: l.metrics.close,
}

var err error
if l.TLSHandshakeTimeout <= 0 {
err = tconn.Handshake()
} else {
ctx, cancel := context.WithTimeout(context.Background(), l.TLSHandshakeTimeout)
err = tconn.HandshakeContext(ctx)
cancel()
if l.TLSConfig != nil {
conn = tls.Server(conn, l.TLSConfig)
}

return tconn, err
return conn, nil
}

func (l *Listener) Addr() net.Addr {
Expand Down
16 changes: 3 additions & 13 deletions net_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,9 @@ func addr2Host(addr string) string {
}

type listenerMetrics struct {
accepted prometheus.Counter
errors prometheus.Counter
tlsErrors prometheus.Counter
closed prometheus.Counter
accepted prometheus.Counter
errors prometheus.Counter
closed prometheus.Counter
}

func newListenerMetrics(r prometheus.Registerer, namespace string) *listenerMetrics {
Expand All @@ -105,11 +104,6 @@ func newListenerMetrics(r prometheus.Registerer, namespace string) *listenerMetr
Namespace: namespace,
Help: "Number of listener errors when accepting connections",
}),
tlsErrors: f.NewCounter(prometheus.CounterOpts{
Name: "listener_tls_errors_total",
Namespace: namespace,
Help: "Number of TLS handshake errors",
}),
closed: f.NewCounter(prometheus.CounterOpts{
Name: "listener_closed_total",
Namespace: namespace,
Expand All @@ -126,10 +120,6 @@ func (m *listenerMetrics) error() {
m.errors.Inc()
}

func (m *listenerMetrics) tlsError() {
m.tlsErrors.Inc()
}

func (m *listenerMetrics) close() {
m.closed.Inc()
}
28 changes: 0 additions & 28 deletions net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,34 +344,6 @@ func TestListenerMetricsErrors(t *testing.T) {
golden.DiffPrometheusMetrics(t, r)
}

func TestListenerTLSHandshakeTimeout(t *testing.T) {
r := prometheus.NewRegistry()
l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
TLSConfig: selfSingedCert(),
TLSHandshakeTimeout: 100 * time.Millisecond,
PromConfig: PromConfig{
PromNamespace: "test",
PromRegistry: r,
},
}
defer l.Close()

l.listenAndWait(t)
go l.acceptAndCopy()

conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}
defer conn.Close()

time.Sleep(l.TLSHandshakeTimeout * 2)

golden.DiffPrometheusMetrics(t, r)
}

func selfSingedCert() *tls.Config {
ssc := certutil.ECDSASelfSignedCert()
ssc.Hosts = append(ssc.Hosts, "localhost")
Expand Down
3 changes: 0 additions & 3 deletions testdata/TestListenerMetricsAccepted.golden.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,3 @@ test_listener_closed_total 10
# HELP test_listener_errors_total Number of listener errors when accepting connections
# TYPE test_listener_errors_total counter
test_listener_errors_total 0
# HELP test_listener_tls_errors_total Number of TLS handshake errors
# TYPE test_listener_tls_errors_total counter
test_listener_tls_errors_total 0
3 changes: 0 additions & 3 deletions testdata/TestListenerMetricsAcceptedWithTLS.golden.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,3 @@ test_listener_closed_total 10
# HELP test_listener_errors_total Number of listener errors when accepting connections
# TYPE test_listener_errors_total counter
test_listener_errors_total 0
# HELP test_listener_tls_errors_total Number of TLS handshake errors
# TYPE test_listener_tls_errors_total counter
test_listener_tls_errors_total 0
3 changes: 0 additions & 3 deletions testdata/TestListenerMetricsClosed.golden.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,3 @@ test_listener_closed_total 1
# HELP test_listener_errors_total Number of listener errors when accepting connections
# TYPE test_listener_errors_total counter
test_listener_errors_total 0
# HELP test_listener_tls_errors_total Number of TLS handshake errors
# TYPE test_listener_tls_errors_total counter
test_listener_tls_errors_total 0
3 changes: 0 additions & 3 deletions testdata/TestListenerMetricsErrors.golden.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,3 @@ test_listener_closed_total 0
# HELP test_listener_errors_total Number of listener errors when accepting connections
# TYPE test_listener_errors_total counter
test_listener_errors_total 1
# HELP test_listener_tls_errors_total Number of TLS handshake errors
# TYPE test_listener_tls_errors_total counter
test_listener_tls_errors_total 0
12 changes: 0 additions & 12 deletions testdata/TestListenerTLSHandshakeTimeout.golden.txt

This file was deleted.