Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: improve connection errors from RPCs and Dial WithReturnConnectionError #4190

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1336,6 +1336,9 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
// We got the preface - huzzah! things are good.
case <-onCloseCalled:
// The transport has already closed - noop.
if newTr.LastConnectionError() != nil {
return nil, nil, newTr.LastConnectionError()
}
return nil, nil, errors.New("connection closed")
// TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix.
}
Expand Down
17 changes: 17 additions & 0 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ type http2Client struct {
bufferPool *bufferPool

connectionID uint64

lceMu sync.Mutex // protects lastConnectionError
lastConnectionError error
}

func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
Expand Down Expand Up @@ -1309,6 +1312,7 @@ func (t *http2Client) reader() {
// Check the validity of server preface.
frame, err := t.framer.fr.ReadFrame()
if err != nil {
t.updateConnectionError(err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing about my design in #4163 that I preferred is that adding an error to Close means we're forced to inspect every place it is called (now and in the future) so we can improve our error reporting in all instances where the transport is closed. With this alternate implementation, we only optionally improve it. I'm fine with the LastConnectionError() method (though it is stateful, and, as such, I do prefer the alternative of making OnClose accept an error result). But we should really make http2Client.Close accept an error. If there is no useful error at the site where Close is called, just pass ErrConnClosing.

If we do that, we also know that whenever OnClose is called, there must be a non-nil error, and we can simplify the clientconn.go code to always return nil, nil, t.LastConnectionError() (without the nil check).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will go with your idea. Should I keep LastConnectionError or have onClose callback to receive the connection error as parameter?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LastConnectionError is stateful, meaning it requires synchronization and opens the possibility for a new class of bugs. I think if OnClose takes an error to pass the information, it makes for a better flow. (This is "share memory by communicating instead of communicating by sharing memory," it just doesn't require a channel.)

t.Close() // this kicks off resetTransport, so must be last before return
return
}
Expand Down Expand Up @@ -1353,6 +1357,7 @@ func (t *http2Client) reader() {
}
continue
} else {
t.updateConnectionError(err)
// Transport error.
t.Close()
return
Expand Down Expand Up @@ -1525,3 +1530,15 @@ func (t *http2Client) getOutFlowWindow() int64 {
return -2
}
}

func (t *http2Client) LastConnectionError() error {
t.lceMu.Lock()
defer t.lceMu.Unlock()
return t.lastConnectionError
}

func (t *http2Client) updateConnectionError(err error) {
t.lceMu.Lock()
t.lastConnectionError = err
t.lceMu.Unlock()
}
3 changes: 3 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ type ClientTransport interface {

// IncrMsgRecv increments the number of message received through this transport.
IncrMsgRecv()

// LastConnectionError returns the last recorded connection error.
LastConnectionError() error
}

// ServerTransport is the common interface for all gRPC server-side transport
Expand Down
181 changes: 181 additions & 0 deletions test/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ package test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"strings"
"testing"
Expand Down Expand Up @@ -545,3 +552,177 @@ func (s) TestServerCredsDispatch(t *testing.T) {
t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err)
}
}

type clientCertificates struct {
validCert *tls.Certificate
selfSignedCert *tls.Certificate
expiredCert *tls.Certificate
}

func (c *clientCertificates) generateCertificates(ca *tls.Certificate) error {
var err error

c.validCert, err = c.generateCert(ca, time.Now().Add(time.Hour))
if err != nil {
return fmt.Errorf("failed to generate valid cert: %v", err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to simplify things significantly by just panicking whenever an error happens inside generateCert. Failing "gracefully" for something we expect to pass, during initialization, isn't important.

}

c.expiredCert, err = c.generateCert(ca, time.Now().Add(-time.Hour))
if err != nil {
return fmt.Errorf("failed to generate expired cert: %v", err)
}

c.selfSignedCert, err = c.generateCert(nil, time.Now().Add(time.Hour))
if err != nil {
return fmt.Errorf("failed to generate self-signed cert: %v", err)
}

return nil
}

func (c *clientCertificates) generateCert(ca *tls.Certificate, notAfter time.Time) (*tls.Certificate, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("RSA key generation failed: %v", err)
}

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, fmt.Errorf("failed to generate serial number: %v", err)
}

template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "Test Cert",
},
NotBefore: time.Now(),
NotAfter: notAfter,

KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}

// if ca is nil, self-sign the certificate
var (
parentCert *x509.Certificate = template
parentKey interface{} = key
)
if ca != nil {
parentCert, err = x509.ParseCertificate(ca.Certificate[0])
if err != nil {
return nil, fmt.Errorf("failed to parse CA certificate: %v", err)
}
parentKey = ca.PrivateKey
}

certData, err := x509.CreateCertificate(rand.Reader, template, parentCert, &key.PublicKey, parentKey)
if err != nil {
return nil, fmt.Errorf("failed to create certificate: %v", err)
}

keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
certPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certData,
})

cert, err := tls.X509KeyPair(certPem, keyPEM)
if err != nil {
return nil, fmt.Errorf("failed to parse public/private certs: %v", err)
}

return &cert, nil
}

func (s) TestClientCredsHandshakeFailure(t *testing.T) {
// load server certificate
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
if err != nil {
t.Fatal(err)
}

// load server CA certificate
ca, err := tls.LoadX509KeyPair(testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/server_ca_key.pem"))
if err != nil {
t.Fatal(err)
}

// create server tls config
roots := x509.NewCertPool()
roots.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: ca.Certificate[0],
}))

tlsCfg := &tls.Config{
Certificates: []tls.Certificate{cert},
ClientCAs: roots,
ClientAuth: tls.RequireAndVerifyClientCert,
}

// start server listener
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
defer lis.Close()

// start the test server using the credentials
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsCfg)))
go s.Serve(lis)
defer s.Stop()

// pre-generate client certificates
clientCerts := clientCertificates{}
if err := clientCerts.generateCertificates(&ca); err != nil {
Comment on lines +682 to +683
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine these into clientCerts, err := newClientCertificates(&ca)?

t.Fatal(err)
}

tests := []struct {
cert *tls.Certificate
shouldFail bool
expectedError string
}{
{&tls.Certificate{}, true, "remote error: tls: bad certificate"},
{clientCerts.expiredCert, true, "remote error: tls: bad certificate"},
{clientCerts.selfSignedCert, true, "remote error: tls: bad certificate"},
{clientCerts.validCert, false, ""},
}

for i, test := range tests {
cfg := &tls.Config{
Certificates: []tls.Certificate{*test.cert},
InsecureSkipVerify: true, // not intrested in server certificates
}
creds := credentials.NewTLS(cfg)

dialCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
cc, err := grpc.DialContext(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also add a check for RPC errors:

cc, err := grpc.Dial(.../* no WithBlock or WithReturnConnectionError*/)
...
client := testpb.NewTestServiceClient(cc)
res, err := client.EmptyCall(...)
// Ensure err contains the connection error

You should be able to use the stubserver.StubServer to make it easy to create the client and server (server and dial options can be provided to the Start method); check the other tests for examples.

dialCtx,
lis.Addr().String(),
grpc.WithTransportCredentials(creds),
grpc.WithReturnConnectionError(),
)
if err != nil {
if !test.shouldFail {
t.Fatalf("Test #%d: failed with error %v", i, err)
} else if !strings.Contains(err.Error(), test.expectedError) {
// return error is non-deterministic, therfore just log
t.Logf("Test #%d: error %q does not contain %q", i, err, test.expectedError)
}
} else if err == nil && test.shouldFail {
t.Fatalf("Tesh #%d: should have failed, but it ran successfully.", i)
}

if cc != nil {
cc.Close()
}
cancel()
}
}