diff --git a/clientconn.go b/clientconn.go index 77a08fd33bf8..e458135a9a7b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1336,6 +1336,9 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne // We got the preface - huzzah! things are good. case <-onCloseCalled: // The transport has already closed - noop. + if newTr.LastConnectionError() != nil { + return nil, nil, newTr.LastConnectionError() + } return nil, nil, errors.New("connection closed") // TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix. } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 8902b7f90d9d..dad83c1d9990 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -137,6 +137,9 @@ type http2Client struct { bufferPool *bufferPool connectionID uint64 + + lceMu sync.Mutex // protects lastConnectionError + lastConnectionError error } func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) { @@ -1309,6 +1312,7 @@ func (t *http2Client) reader() { // Check the validity of server preface. frame, err := t.framer.fr.ReadFrame() if err != nil { + t.updateConnectionError(err) t.Close() // this kicks off resetTransport, so must be last before return return } @@ -1353,6 +1357,7 @@ func (t *http2Client) reader() { } continue } else { + t.updateConnectionError(err) // Transport error. t.Close() return @@ -1525,3 +1530,15 @@ func (t *http2Client) getOutFlowWindow() int64 { return -2 } } + +func (t *http2Client) LastConnectionError() error { + t.lceMu.Lock() + defer t.lceMu.Unlock() + return t.lastConnectionError +} + +func (t *http2Client) updateConnectionError(err error) { + t.lceMu.Lock() + t.lastConnectionError = err + t.lceMu.Unlock() +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 9c8f79cb4b29..40cc7df65853 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -664,6 +664,9 @@ type ClientTransport interface { // IncrMsgRecv increments the number of message received through this transport. IncrMsgRecv() + + // LastConnectionError returns the last recorded connection error. + LastConnectionError() error } // ServerTransport is the common interface for all gRPC server-side transport diff --git a/test/creds_test.go b/test/creds_test.go index 6b3fc2a46076..f167af27d9e5 100644 --- a/test/creds_test.go +++ b/test/creds_test.go @@ -20,8 +20,15 @@ package test import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" + "math/big" "net" "strings" "testing" @@ -545,3 +552,177 @@ func (s) TestServerCredsDispatch(t *testing.T) { t.Errorf("Read() = %v, %v; want n>0, ", n, err) } } + +type clientCertificates struct { + validCert *tls.Certificate + selfSignedCert *tls.Certificate + expiredCert *tls.Certificate +} + +func (c *clientCertificates) generateCertificates(ca *tls.Certificate) error { + var err error + + c.validCert, err = c.generateCert(ca, time.Now().Add(time.Hour)) + if err != nil { + return fmt.Errorf("failed to generate valid cert: %v", err) + } + + c.expiredCert, err = c.generateCert(ca, time.Now().Add(-time.Hour)) + if err != nil { + return fmt.Errorf("failed to generate expired cert: %v", err) + } + + c.selfSignedCert, err = c.generateCert(nil, time.Now().Add(time.Hour)) + if err != nil { + return fmt.Errorf("failed to generate self-signed cert: %v", err) + } + + return nil +} + +func (c *clientCertificates) generateCert(ca *tls.Certificate, notAfter time.Time) (*tls.Certificate, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("RSA key generation failed: %v", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "Test Cert", + }, + NotBefore: time.Now(), + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + // if ca is nil, self-sign the certificate + var ( + parentCert *x509.Certificate = template + parentKey interface{} = key + ) + if ca != nil { + parentCert, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse CA certificate: %v", err) + } + parentKey = ca.PrivateKey + } + + certData, err := x509.CreateCertificate(rand.Reader, template, parentCert, &key.PublicKey, parentKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + certPem := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certData, + }) + + cert, err := tls.X509KeyPair(certPem, keyPEM) + if err != nil { + return nil, fmt.Errorf("failed to parse public/private certs: %v", err) + } + + return &cert, nil +} + +func (s) TestClientCredsHandshakeFailure(t *testing.T) { + // load server certificate + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + t.Fatal(err) + } + + // load server CA certificate + ca, err := tls.LoadX509KeyPair(testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/server_ca_key.pem")) + if err != nil { + t.Fatal(err) + } + + // create server tls config + roots := x509.NewCertPool() + roots.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: ca.Certificate[0], + })) + + tlsCfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: roots, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + // start server listener + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer lis.Close() + + // start the test server using the credentials + s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsCfg))) + go s.Serve(lis) + defer s.Stop() + + // pre-generate client certificates + clientCerts := clientCertificates{} + if err := clientCerts.generateCertificates(&ca); err != nil { + t.Fatal(err) + } + + tests := []struct { + cert *tls.Certificate + shouldFail bool + expectedError string + }{ + {&tls.Certificate{}, true, "remote error: tls: bad certificate"}, + {clientCerts.expiredCert, true, "remote error: tls: bad certificate"}, + {clientCerts.selfSignedCert, true, "remote error: tls: bad certificate"}, + {clientCerts.validCert, false, ""}, + } + + for i, test := range tests { + cfg := &tls.Config{ + Certificates: []tls.Certificate{*test.cert}, + InsecureSkipVerify: true, // not intrested in server certificates + } + creds := credentials.NewTLS(cfg) + + dialCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + cc, err := grpc.DialContext( + dialCtx, + lis.Addr().String(), + grpc.WithTransportCredentials(creds), + grpc.WithReturnConnectionError(), + ) + if err != nil { + if !test.shouldFail { + t.Fatalf("Test #%d: failed with error %v", i, err) + } else if !strings.Contains(err.Error(), test.expectedError) { + // return error is non-deterministic, therfore just log + t.Logf("Test #%d: error %q does not contain %q", i, err, test.expectedError) + } + } else if err == nil && test.shouldFail { + t.Fatalf("Tesh #%d: should have failed, but it ran successfully.", i) + } + + if cc != nil { + cc.Close() + } + cancel() + } +}