Skip to content

Commit

Permalink
Properly order for both TLS and non-TLS the encapsulation.
Browse files Browse the repository at this point in the history
Previous patch was using conn to write magic number, but
since TLS Connection had been started, if corrupted the stream when
TLS was active.

To make it work, changed `TLSWrapper` to be another indirection using
the datacenter as paramter, thus, the TLS knowledge about TLS being
enable is known in advance.

This involves some refactoring, but we are sure the condtions to
decide whether TLS is enabled are identical.
  • Loading branch information
pierresouchay committed Jan 5, 2021
1 parent c3e172d commit deeb577
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 44 deletions.
3 changes: 3 additions & 0 deletions .changelog/9494.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
client: properly set GRPC over RPC magic numbers when encryption was not set or partially set in the cluster with streaming enabled
```
8 changes: 6 additions & 2 deletions agent/consul/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,9 @@ func TestRPC_PreventsTLSNesting(t *testing.T) {

// Start tls client
tlsWrap := s1.tlsConfigurator.OutgoingRPCWrapper()
tlsConn, err := tlsWrap("dc1", conn)
tlsConnWrapper, useTLS := tlsWrap("dc1", conn)
require.True(t, useTLS)
tlsConn, err := tlsConnWrapper(conn)
require.NoError(t, err)

// Write Inner magic byte
Expand Down Expand Up @@ -617,7 +619,9 @@ func connectClient(t *testing.T, s1 *Server, mb pool.RPCType, useTLS, wantOpen b
require.NoError(t, err)

if useTLS {
tlsConn, err := tlsWrap(s1.config.Datacenter, conn)
tlsConnW, useTLS := tlsWrap(s1.config.Datacenter, conn)
require.Equal(t, useTLS, useTLS)
tlsConn, err := tlsConnW(conn)
// Subtly, tlsWrap will NOT actually do a handshake in this case - it only
// does so for some configs, so even if the server closed the conn before
// handshake this won't fail and it's only when we attempt to read or write
Expand Down
3 changes: 2 additions & 1 deletion agent/consul/status_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) {
}

// Wrap the connection in a TLS client
tlsConn, err := wrapper(s.config.Datacenter, conn)
tlsConnW, _ := wrapper(s.config.Datacenter, conn)
tlsConn, err := tlsConnW(conn)
if err != nil {
conn.Close()
return nil, err
Expand Down
28 changes: 11 additions & 17 deletions agent/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ type ServerLocator interface {
}

// TLSWrapper wraps a non-TLS connection and returns a connection with TLS
// enabled.
type TLSWrapper func(dc string, conn net.Conn) (net.Conn, error)
// possibly enabled.
// The bool indicates whether the returned function will use TLS.
type TLSWrapper func(dc string, conn net.Conn) (func(net.Conn) (net.Conn, error), bool)

type dialer func(context.Context, string) (net.Conn, error)

Expand Down Expand Up @@ -74,7 +75,7 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error)

// newDialer returns a gRPC dialer function that conditionally wraps the connection
// with TLS based on the Server.useTLS value.
func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context, string) (net.Conn, error) {
func newDialer(servers ServerLocator, tlsWrapper TLSWrapper) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (net.Conn, error) {
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", addr)
Expand All @@ -87,27 +88,20 @@ func newDialer(servers ServerLocator, wrapper TLSWrapper) func(context.Context,
conn.Close()
return nil, err
}

if server.UseTLS {
if wrapper == nil {
wrapper, tlsEnabled := tlsWrapper(server.Datacenter, conn)
if server.UseTLS && tlsEnabled {
// If connection is upgraded to TLS, mark the stream as RPCTLS
if _, err := conn.Write([]byte{byte(pool.RPCTLS)}); err != nil {
conn.Close()
return nil, fmt.Errorf("TLS enabled but got nil TLS wrapper")
return nil, err
}

// Wrap the connection in a TLS client, return same conn if TLS disabled
tlsConn, err := wrapper(server.Datacenter, conn)
tlsConn, err := wrapper(conn)
if err != nil {
conn.Close()
return nil, err
}
if tlsConn != conn {
// If connection is upgraded to TLS, mark the stream as RPCTLS
if _, err := conn.Write([]byte{byte(pool.RPCTLS)}); err != nil {
conn.Close()
return nil, err
}
conn = tlsConn
}
conn = tlsConn
}

_, err = conn.Write([]byte{pool.RPCGRPC})
Expand Down
22 changes: 17 additions & 5 deletions agent/grpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) {
})

var called bool
wrapper := func(_ string, conn net.Conn) (net.Conn, error) {
wrapper := func(_ string, conn net.Conn) (func(net.Conn) (net.Conn, error), bool) {
called = true
return conn, nil
fn := func(conn net.Conn) (net.Conn, error) {
return conn, nil
}
return fn, true
}
dial := newDialer(builder, wrapper)
ctx := context.Background()
Expand Down Expand Up @@ -78,11 +81,20 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) {
require.True(t, atomic.LoadInt32(&srv.rpc.tlsConnEstablished) > 0)
}

// noOpTLSWrapper Generate a TLVWrapper that does not encypt anything, but is not nil
func noOpTLSWrapper() func(dc string, conn net.Conn) (func(net.Conn) (net.Conn, error), bool) {
return func(dc string, conn net.Conn) (func(net.Conn) (net.Conn, error), bool) {
return func(net.Conn) (net.Conn, error) {
return conn, nil
}, false
}
}

func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) {
count := 4
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil)
pool := NewClientConnPool(res, noOpTLSWrapper())

for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
Expand Down Expand Up @@ -119,7 +131,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Rebalance(t *testing.T) {
count := 5
res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil)
pool := NewClientConnPool(res, noOpTLSWrapper())

for i := 0; i < count; i++ {
name := fmt.Sprintf("server-%d", i)
Expand Down Expand Up @@ -168,7 +180,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) {

res := resolver.NewServerResolverBuilder(newConfig(t))
registerWithGRPC(t, res)
pool := NewClientConnPool(res, nil)
pool := NewClientConnPool(res, noOpTLSWrapper())

for _, dc := range dcs {
name := "server-0-" + dc
Expand Down
2 changes: 1 addition & 1 deletion agent/grpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (f *fakeRPCListener) handleConn(conn net.Conn) {
f.handleConn(conn)

default:
fmt.Println("ERROR: unexpected byte", typ)
fmt.Println("server_test ERROR: unexpected byte", typ)
conn.Close()
}
}
Expand Down
3 changes: 2 additions & 1 deletion agent/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ func (p *ConnPool) dial(
}

// Wrap the connection in a TLS client
tlsConn, err := wrapper(dc, conn)
tlsConnBuilder, _ := wrapper(dc, conn)
tlsConn, err := tlsConnBuilder(conn)
if err != nil {
conn.Close()
return nil, nil, err
Expand Down
22 changes: 14 additions & 8 deletions tlsutil/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type ALPNWrapper func(dc, nodeName, alpnProto string, conn net.Conn) (net.Conn,
// DCWrapper is a function that is used to wrap a non-TLS connection
// and returns an appropriate TLS connection or error. This takes
// a datacenter as an argument.
type DCWrapper func(dc string, conn net.Conn) (net.Conn, error)
type DCWrapper func(dc string, conn net.Conn) (func(conn net.Conn) (net.Conn, error), bool)

// Wrapper is a variant of DCWrapper, where the DC is provided as
// a constant value. This is usually done by currying DCWrapper.
Expand Down Expand Up @@ -153,7 +153,8 @@ func SpecificDC(dc string, tlsWrap DCWrapper) Wrapper {
return nil
}
return func(conn net.Conn) (net.Conn, error) {
return tlsWrap(dc, conn)
wrap, _ := tlsWrap(dc, conn)
return wrap(conn)
}
}

Expand Down Expand Up @@ -759,18 +760,23 @@ func (c *Configurator) OutgoingALPNRPCConfig() *tls.Config {
return config
}

//type DCWrapper func(dc string, conn net.Conn) (func(conn net.Conn) (net.Conn, error), bool)

// OutgoingRPCWrapper wraps the result of OutgoingRPCConfig in a DCWrapper. It
// decides if verify server hostname should be used.
func (c *Configurator) OutgoingRPCWrapper() DCWrapper {
c.log("OutgoingRPCWrapper")

// Generate the wrapper based on dc
return func(dc string, conn net.Conn) (net.Conn, error) {
if c.UseTLS(dc) {
return c.wrapTLSClient(dc, conn)
fn := func(dc string, conn net.Conn) (func(net.Conn) (net.Conn, error), bool) {
useTLS := c.UseTLS(dc)
wrapper := func(conn net.Conn) (net.Conn, error) {
if useTLS {
return c.wrapTLSClient(dc, conn)
}
return conn, nil
}
return conn, nil
return wrapper, useTLS
}
return fn
}

func (c *Configurator) UseTLS(dc string) bool {
Expand Down
39 changes: 30 additions & 9 deletions tlsutil/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ func TestConfigurator_outgoingWrapper_OK(t *testing.T) {
wrap := c.OutgoingRPCWrapper()
require.NotNil(t, wrap)

tlsClient, err := wrap("dc1", client)
tlsClientWrapper, useTLS := wrap("dc1", client)
require.True(t, useTLS)
tlsClient, err := tlsClientWrapper(client)
require.NoError(t, err)

defer tlsClient.Close()
Expand Down Expand Up @@ -116,7 +118,9 @@ func TestConfigurator_outgoingWrapper_noverify_OK(t *testing.T) {
wrap := c.OutgoingRPCWrapper()
require.NotNil(t, wrap)

tlsClient, err := wrap("dc1", client)
tlsClientWrapper, useTLS := wrap("dc1", client)
require.True(t, useTLS)
tlsClient, err := tlsClientWrapper(client)
require.NoError(t, err)

defer tlsClient.Close()
Expand Down Expand Up @@ -146,7 +150,9 @@ func TestConfigurator_outgoingWrapper_BadDC(t *testing.T) {
require.NoError(t, err)
wrap := c.OutgoingRPCWrapper()

tlsClient, err := wrap("dc2", client)
tlsClientWrapper, useTLS := wrap("dc2", client)
require.True(t, useTLS)
tlsClient, err := tlsClientWrapper(client)
require.NoError(t, err)

err = tlsClient.(*tls.Conn).Handshake()
Expand Down Expand Up @@ -176,7 +182,9 @@ func TestConfigurator_outgoingWrapper_BadCert(t *testing.T) {
require.NoError(t, err)
wrap := c.OutgoingRPCWrapper()

tlsClient, err := wrap("dc1", client)
tlsClientWrapper, useTLS := wrap("dc1", client)
require.True(t, useTLS)
tlsClient, err := tlsClientWrapper(client)
require.NoError(t, err)

err = tlsClient.(*tls.Conn).Handshake()
Expand Down Expand Up @@ -443,7 +451,11 @@ func TestConfigurator_loadKeyPair(t *testing.T) {

func TestConfig_SpecifyDC(t *testing.T) {
require.Nil(t, SpecificDC("", nil))
dcwrap := func(dc string, conn net.Conn) (net.Conn, error) { return nil, nil }
dcwrap := func(dc string, conn net.Conn) (func(net.Conn) (net.Conn, error), bool) {
return func(net.Conn) (net.Conn, error) {
return conn, nil
}, false
}
wrap := SpecificDC("", dcwrap)
require.NotNil(t, wrap)
conn, err := wrap(nil)
Expand Down Expand Up @@ -964,7 +976,9 @@ func TestConfigurator_OutgoingRPCWrapper(t *testing.T) {
wrapper := c.OutgoingRPCWrapper()
require.NotNil(t, wrapper)
conn := &net.TCPConn{}
cWrap, err := wrapper("", conn)
cWrapper, ok := wrapper("", conn)
require.False(t, ok)
cWrap, err := cWrapper(conn)
require.NoError(t, err)
require.Equal(t, conn, cWrap)

Expand All @@ -976,7 +990,9 @@ func TestConfigurator_OutgoingRPCWrapper(t *testing.T) {

wrapper = c.OutgoingRPCWrapper()
require.NotNil(t, wrapper)
cWrap, err = wrapper("", conn)
cWrapper, ok = wrapper("", conn)
require.True(t, ok)
cWrap, err = cWrapper(conn)
require.EqualError(t, err, "invalid argument")
require.NotEqual(t, conn, cWrap)
}
Expand All @@ -986,7 +1002,10 @@ func TestConfigurator_OutgoingALPNRPCWrapper(t *testing.T) {
wrapper := c.OutgoingRPCWrapper()
require.NotNil(t, wrapper)
conn := &net.TCPConn{}
cWrap, err := wrapper("", conn)
cWrapper, ok := wrapper("", conn)
require.False(t, ok)
cWrap, err := cWrapper(conn)

require.NoError(t, err)
require.Equal(t, conn, cWrap)

Expand All @@ -998,7 +1017,9 @@ func TestConfigurator_OutgoingALPNRPCWrapper(t *testing.T) {

wrapper = c.OutgoingRPCWrapper()
require.NotNil(t, wrapper)
cWrap, err = wrapper("", conn)
cWrapper, ok = wrapper("", conn)
require.True(t, ok)
cWrap, err = cWrapper(conn)
require.EqualError(t, err, "invalid argument")
require.NotEqual(t, conn, cWrap)
}
Expand Down

0 comments on commit deeb577

Please sign in to comment.