Skip to content

Commit

Permalink
http2: use (*tls.Dialer).DialContext in dialTLS
Browse files Browse the repository at this point in the history
This lets us propagate the request context into the TLS
handshake.

Related to CL 295370
Updates golang/go#32406

Change-Id: Ie10c301be19b57b4b3e46ac31bbe87679e1eebc7
Reviewed-on: https://go-review.googlesource.com/c/net/+/295173
Trust: Johan Brandhorst-Satzkorn <[email protected]>
Run-TryBot: Johan Brandhorst-Satzkorn <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Reviewed-by: Brad Fitzpatrick <[email protected]>
Reviewed-by: Filippo Valsorda <[email protected]>
  • Loading branch information
johanbrandhorst authored and bradfitz committed May 4, 2021
1 parent 7fd8e65 commit bbd867f
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 49 deletions.
79 changes: 57 additions & 22 deletions http2/client_conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
package http2

import (
"context"
"crypto/tls"
"errors"
"net/http"
"sync"
)
Expand Down Expand Up @@ -78,61 +80,69 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis
// It gets its own connection.
traceGetConn(req, addr)
const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse)
cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
if err != nil {
return nil, err
}
return cc, nil
}
p.mu.Lock()
for _, cc := range p.conns[addr] {
if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
traceGetConn(req, addr)
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
if st := cc.idleState(); st.canTakeNewRequest {
if p.shouldTraceGetConn(st) {
traceGetConn(req, addr)
}
p.mu.Unlock()
return cc, nil
}
}
if !dialOnMiss {
p.mu.Unlock()
return cc, nil
return nil, ErrNoCachedConn
}
}
if !dialOnMiss {
traceGetConn(req, addr)
call := p.getStartDialLocked(req.Context(), addr)
p.mu.Unlock()
return nil, ErrNoCachedConn
<-call.done
if shouldRetryDial(call, req) {
continue
}
return call.res, call.err
}
traceGetConn(req, addr)
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
return call.res, call.err
}

// dialCall is an in-flight Transport dial call to a host.
type dialCall struct {
_ incomparable
p *clientConnPool
_ incomparable
p *clientConnPool
// the context associated with the request
// that created this dialCall
ctx context.Context
done chan struct{} // closed when done
res *ClientConn // valid after done is closed
err error // valid after done is closed
}

// requires p.mu is held.
func (p *clientConnPool) getStartDialLocked(addr string) *dialCall {
func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
if call, ok := p.dialing[addr]; ok {
// A dial is already in-flight. Don't start another.
return call
}
call := &dialCall{p: p, done: make(chan struct{})}
call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
if p.dialing == nil {
p.dialing = make(map[string]*dialCall)
}
p.dialing[addr] = call
go call.dial(addr)
go call.dial(call.ctx, addr)
return call
}

// run in its own goroutine.
func (c *dialCall) dial(addr string) {
func (c *dialCall) dial(ctx context.Context, addr string) {
const singleUse = false // shared conn
c.res, c.err = c.p.t.dialClientConn(addr, singleUse)
c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
close(c.done)

c.p.mu.Lock()
Expand Down Expand Up @@ -276,3 +286,28 @@ type noDialClientConnPool struct{ *clientConnPool }
func (p noDialClientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
return p.getClientConn(req, addr, noDialOnMiss)
}

// shouldRetryDial reports whether the current request should
// retry dialing after the call finished unsuccessfully, for example
// if the dial was canceled because of a context cancellation or
// deadline expiry.
func shouldRetryDial(call *dialCall, req *http.Request) bool {
if call.err == nil {
// No error, no need to retry
return false
}
if call.ctx == req.Context() {
// If the call has the same context as the request, the dial
// should not be retried, since any cancellation will have come
// from this request.
return false
}
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
// If the call error is not because of a context cancellation or a deadline expiry,
// the dial should not be retried.
return false
}
// Only retry if the error is a context cancellation error or deadline expiry
// and the context associated with the call was canceled or expired.
return call.ctx.Err() != nil
}
42 changes: 18 additions & 24 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,12 @@ func canRetryError(err error) bool {
return false
}

func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
tconn, err := t.dialTLS(ctx)("tcp", addr, t.newTLSConfig(host))
if err != nil {
return nil, err
}
Expand All @@ -590,34 +590,28 @@ func (t *Transport) newTLSConfig(host string) *tls.Config {
return cfg
}

func (t *Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
func (t *Transport) dialTLS(ctx context.Context) func(string, string, *tls.Config) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
return t.dialTLSDefault
}

func (t *Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
cn, err := tls.Dial(network, addr, cfg)
if err != nil {
return nil, err
}
if err := cn.Handshake(); err != nil {
return nil, err
}
if !cfg.InsecureSkipVerify {
if err := cn.VerifyHostname(cfg.ServerName); err != nil {
return func(network, addr string, cfg *tls.Config) (net.Conn, error) {
dialer := &tls.Dialer{
Config: cfg,
}
cn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
state := tlsCn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("http2: could not negotiate protocol mutually")
}
return cn, nil
}
state := cn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, NextProtoTLS)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("http2: could not negotiate protocol mutually")
}
return cn, nil
}

// disableKeepAlives reports whether connections should be closed as
Expand Down
169 changes: 169 additions & 0 deletions http2/transport_go117_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build go1.17
// +build go1.17

package http2

import (
"context"
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"

"testing"
)

func TestTransportDialTLSContext(t *testing.T) {
blockCh := make(chan struct{})
serverTLSConfigFunc := func(ts *httptest.Server) {
ts.Config.TLSConfig = &tls.Config{
// Triggers the server to request the clients certificate
// during TLS handshake.
ClientAuth: tls.RequestClientCert,
}
}
ts := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {},
optOnlyServer,
serverTLSConfigFunc,
)
defer ts.Close()
tr := &Transport{
TLSClientConfig: &tls.Config{
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Tests that the context provided to `req` is
// passed into this function.
close(blockCh)
<-cri.Context().Done()
return nil, cri.Context().Err()
},
InsecureSkipVerify: true,
},
}
defer tr.CloseIdleConnections()
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
req = req.WithContext(ctx)
errCh := make(chan error)
go func() {
defer close(errCh)
res, err := tr.RoundTrip(req)
if err != nil {
errCh <- err
return
}
res.Body.Close()
}()
// Wait for GetClientCertificate handler to be called
<-blockCh
// Cancel the context
cancel()
// Expect the cancellation error here
err = <-errCh
if err == nil {
t.Fatal("cancelling context during client certificate fetch did not error as expected")
return
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected error returned after cancellation: %v", err)
}
}

// TestDialRaceResumesDial tests that, given two concurrent requests
// to the same address, when the first Dial is interrupted because
// the first request's context is cancelled, the second request
// resumes the dial automatically.
func TestDialRaceResumesDial(t *testing.T) {
blockCh := make(chan struct{})
serverTLSConfigFunc := func(ts *httptest.Server) {
ts.Config.TLSConfig = &tls.Config{
// Triggers the server to request the clients certificate
// during TLS handshake.
ClientAuth: tls.RequestClientCert,
}
}
ts := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {},
optOnlyServer,
serverTLSConfigFunc,
)
defer ts.Close()
tr := &Transport{
TLSClientConfig: &tls.Config{
GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
select {
case <-blockCh:
// If we already errored, return without error.
return &tls.Certificate{}, nil
default:
}
close(blockCh)
<-cri.Context().Done()
return nil, cri.Context().Err()
},
InsecureSkipVerify: true,
},
}
defer tr.CloseIdleConnections()
req, err := http.NewRequest(http.MethodGet, ts.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
// Create two requests with independent cancellation.
ctx1, cancel1 := context.WithCancel(context.Background())
defer cancel1()
req1 := req.WithContext(ctx1)
ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()
req2 := req.WithContext(ctx2)
errCh := make(chan error)
go func() {
res, err := tr.RoundTrip(req1)
if err != nil {
errCh <- err
return
}
res.Body.Close()
}()
successCh := make(chan struct{})
go func() {
// Don't start request until first request
// has initiated the handshake.
<-blockCh
res, err := tr.RoundTrip(req2)
if err != nil {
errCh <- err
return
}
res.Body.Close()
// Close successCh to indicate that the second request
// made it to the server successfully.
close(successCh)
}()
// Wait for GetClientCertificate handler to be called
<-blockCh
// Cancel the context first
cancel1()
// Expect the cancellation error here
err = <-errCh
if err == nil {
t.Fatal("cancelling context during client certificate fetch did not error as expected")
return
}
if !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected error returned after cancellation: %v", err)
}
select {
case err := <-errCh:
t.Fatalf("unexpected second error: %v", err)
case <-successCh:
}
}
9 changes: 6 additions & 3 deletions http2/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3276,7 +3276,8 @@ func TestClientConnPing(t *testing.T) {
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -4278,7 +4279,8 @@ func testClientConnClose(t *testing.T, closeMode closeMode) {
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -4788,7 +4790,8 @@ func TestTransportRoundtripCloseOnWriteError(t *testing.T) {

tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
ctx := context.Background()
cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit bbd867f

Please sign in to comment.