-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
client: improve connection errors from RPCs and Dial WithReturnConnectionError #4190
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,8 +20,15 @@ package test | |
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"errors" | ||
"fmt" | ||
"math/big" | ||
"net" | ||
"strings" | ||
"testing" | ||
|
@@ -545,3 +552,177 @@ func (s) TestServerCredsDispatch(t *testing.T) { | |
t.Errorf("Read() = %v, %v; want n>0, <nil>", n, err) | ||
} | ||
} | ||
|
||
type clientCertificates struct { | ||
validCert *tls.Certificate | ||
selfSignedCert *tls.Certificate | ||
expiredCert *tls.Certificate | ||
} | ||
|
||
func (c *clientCertificates) generateCertificates(ca *tls.Certificate) error { | ||
var err error | ||
|
||
c.validCert, err = c.generateCert(ca, time.Now().Add(time.Hour)) | ||
if err != nil { | ||
return fmt.Errorf("failed to generate valid cert: %v", err) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should be able to simplify things significantly by just panicking whenever an error happens inside |
||
} | ||
|
||
c.expiredCert, err = c.generateCert(ca, time.Now().Add(-time.Hour)) | ||
if err != nil { | ||
return fmt.Errorf("failed to generate expired cert: %v", err) | ||
} | ||
|
||
c.selfSignedCert, err = c.generateCert(nil, time.Now().Add(time.Hour)) | ||
if err != nil { | ||
return fmt.Errorf("failed to generate self-signed cert: %v", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (c *clientCertificates) generateCert(ca *tls.Certificate, notAfter time.Time) (*tls.Certificate, error) { | ||
key, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
return nil, fmt.Errorf("RSA key generation failed: %v", err) | ||
} | ||
|
||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) | ||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to generate serial number: %v", err) | ||
} | ||
|
||
template := &x509.Certificate{ | ||
SerialNumber: serialNumber, | ||
Subject: pkix.Name{ | ||
CommonName: "Test Cert", | ||
}, | ||
NotBefore: time.Now(), | ||
NotAfter: notAfter, | ||
|
||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, | ||
BasicConstraintsValid: true, | ||
} | ||
|
||
// if ca is nil, self-sign the certificate | ||
var ( | ||
parentCert *x509.Certificate = template | ||
parentKey interface{} = key | ||
) | ||
if ca != nil { | ||
parentCert, err = x509.ParseCertificate(ca.Certificate[0]) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to parse CA certificate: %v", err) | ||
} | ||
parentKey = ca.PrivateKey | ||
} | ||
|
||
certData, err := x509.CreateCertificate(rand.Reader, template, parentCert, &key.PublicKey, parentKey) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create certificate: %v", err) | ||
} | ||
|
||
keyPEM := pem.EncodeToMemory(&pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(key), | ||
}) | ||
certPem := pem.EncodeToMemory(&pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: certData, | ||
}) | ||
|
||
cert, err := tls.X509KeyPair(certPem, keyPEM) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to parse public/private certs: %v", err) | ||
} | ||
|
||
return &cert, nil | ||
} | ||
|
||
func (s) TestClientCredsHandshakeFailure(t *testing.T) { | ||
// load server certificate | ||
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// load server CA certificate | ||
ca, err := tls.LoadX509KeyPair(testdata.Path("x509/server_ca_cert.pem"), testdata.Path("x509/server_ca_key.pem")) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// create server tls config | ||
roots := x509.NewCertPool() | ||
roots.AppendCertsFromPEM(pem.EncodeToMemory(&pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: ca.Certificate[0], | ||
})) | ||
|
||
tlsCfg := &tls.Config{ | ||
Certificates: []tls.Certificate{cert}, | ||
ClientCAs: roots, | ||
ClientAuth: tls.RequireAndVerifyClientCert, | ||
} | ||
|
||
// start server listener | ||
lis, err := net.Listen("tcp", "localhost:0") | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
defer lis.Close() | ||
|
||
// start the test server using the credentials | ||
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsCfg))) | ||
go s.Serve(lis) | ||
defer s.Stop() | ||
|
||
// pre-generate client certificates | ||
clientCerts := clientCertificates{} | ||
if err := clientCerts.generateCertificates(&ca); err != nil { | ||
Comment on lines
+682
to
+683
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Combine these into |
||
t.Fatal(err) | ||
} | ||
|
||
tests := []struct { | ||
cert *tls.Certificate | ||
shouldFail bool | ||
expectedError string | ||
}{ | ||
{&tls.Certificate{}, true, "remote error: tls: bad certificate"}, | ||
{clientCerts.expiredCert, true, "remote error: tls: bad certificate"}, | ||
{clientCerts.selfSignedCert, true, "remote error: tls: bad certificate"}, | ||
{clientCerts.validCert, false, ""}, | ||
} | ||
|
||
for i, test := range tests { | ||
cfg := &tls.Config{ | ||
Certificates: []tls.Certificate{*test.cert}, | ||
InsecureSkipVerify: true, // not intrested in server certificates | ||
} | ||
creds := credentials.NewTLS(cfg) | ||
|
||
dialCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
cc, err := grpc.DialContext( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please also add a check for RPC errors: cc, err := grpc.Dial(.../* no WithBlock or WithReturnConnectionError*/)
...
client := testpb.NewTestServiceClient(cc)
res, err := client.EmptyCall(...)
// Ensure err contains the connection error You should be able to use the |
||
dialCtx, | ||
lis.Addr().String(), | ||
grpc.WithTransportCredentials(creds), | ||
grpc.WithReturnConnectionError(), | ||
) | ||
if err != nil { | ||
if !test.shouldFail { | ||
t.Fatalf("Test #%d: failed with error %v", i, err) | ||
} else if !strings.Contains(err.Error(), test.expectedError) { | ||
// return error is non-deterministic, therfore just log | ||
t.Logf("Test #%d: error %q does not contain %q", i, err, test.expectedError) | ||
} | ||
} else if err == nil && test.shouldFail { | ||
t.Fatalf("Tesh #%d: should have failed, but it ran successfully.", i) | ||
} | ||
|
||
if cc != nil { | ||
cc.Close() | ||
} | ||
cancel() | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing about my design in #4163 that I preferred is that adding an
error
toClose
means we're forced to inspect every place it is called (now and in the future) so we can improve our error reporting in all instances where the transport is closed. With this alternate implementation, we only optionally improve it. I'm fine with theLastConnectionError()
method (though it is stateful, and, as such, I do prefer the alternative of makingOnClose
accept an error result). But we should really makehttp2Client.Close
accept anerror
. If there is no useful error at the site whereClose
is called, just passErrConnClosing
.If we do that, we also know that whenever
OnClose
is called, there must be a non-nilerror
, and we can simplify theclientconn.go
code to always returnnil, nil, t.LastConnectionError()
(without the nil check).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I will go with your idea. Should I keep LastConnectionError or have onClose callback to receive the connection error as parameter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LastConnectionError
is stateful, meaning it requires synchronization and opens the possibility for a new class of bugs. I think ifOnClose
takes an error to pass the information, it makes for a better flow. (This is "share memory by communicating instead of communicating by sharing memory," it just doesn't require a channel.)