Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create network layer abstraction to allow in-memory cluster traffic #8173

Merged
merged 1 commit into from
Jan 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions physical/raft/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,14 +503,8 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error {
case opts.ClusterListener == nil:
return errors.New("no cluster listener provided")
default:
// Load the base TLS config from the cluster listener.
baseTLSConfig, err := opts.ClusterListener.TLSConfig(ctx)
if err != nil {
return err
}

// Set the local address and localID in the streaming layer and the raft config.
streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener.Addr(), baseTLSConfig)
streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener)
if err != nil {
return err
}
Expand Down
61 changes: 30 additions & 31 deletions physical/raft/streamlayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func GenerateTLSKey(reader io.Reader) (*TLSKey, error) {
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
SerialNumber: big.NewInt(mathrand.Int63()),
NotBefore: time.Now().Add(-30 * time.Second),
// 30 years of single-active uptime ought to be enough for anybody
// 30 years ought to be enough for anybody
NotAfter: time.Now().Add(262980 * time.Hour),
BasicConstraintsValid: true,
IsCA: true,
Expand Down Expand Up @@ -162,25 +162,26 @@ type raftLayer struct {
dialerFunc func(string, time.Duration) (net.Conn, error)

// TLS config
keyring *TLSKeyring
baseTLSConfig *tls.Config
keyring *TLSKeyring
clusterListener cluster.ClusterHook
}

// NewRaftLayer creates a new raftLayer object. It parses the TLS information
// from the network config.
func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net.Addr, baseTLSConfig *tls.Config) (*raftLayer, error) {
func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) {
clusterAddr := clusterListener.Addr()
switch {
case clusterAddr == nil:
// Clustering disabled on the server, don't try to look for params
return nil, errors.New("no raft addr found")
}

layer := &raftLayer{
addr: clusterAddr,
connCh: make(chan net.Conn),
closeCh: make(chan struct{}),
logger: logger,
baseTLSConfig: baseTLSConfig,
addr: clusterAddr,
connCh: make(chan net.Conn),
closeCh: make(chan struct{}),
logger: logger,
clusterListener: clusterListener,
}

if err := layer.setTLSKeyring(raftTLSKeyring); err != nil {
Expand Down Expand Up @@ -236,6 +237,24 @@ func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error {
return nil
}

func (l *raftLayer) ServerName() string {
key := l.keyring.GetActive()
if key == nil {
return ""
}

return key.parsedCert.Subject.CommonName
}

func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate {
key := l.keyring.GetActive()
if key == nil {
return nil
}

return key.parsedCert
}

func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) {
for _, subj := range requestInfo.AcceptableCAs {
for _, key := range l.keyring.Keys {
Expand Down Expand Up @@ -346,26 +365,6 @@ func (l *raftLayer) Addr() net.Addr {

// Dial is used to create a new outgoing connection
func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {

tlsConfig := l.baseTLSConfig.Clone()

key := l.keyring.GetActive()
if key == nil {
return nil, errors.New("no active key")
}

tlsConfig.NextProtos = []string{consts.RaftStorageALPN}
tlsConfig.ServerName = key.parsedCert.Subject.CommonName

l.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)

pool := x509.NewCertPool()
pool.AddCert(key.parsedCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool

dialer := &net.Dialer{
Timeout: timeout,
}
return tls.DialWithDialer(dialer, "tcp", string(address), tlsConfig)
dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN)
return dialFunc(string(address), timeout)
}
45 changes: 8 additions & 37 deletions vault/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
Expand Down Expand Up @@ -302,15 +301,21 @@ func (c *Core) startClusterListener(ctx context.Context) error {

c.logger.Debug("starting cluster listeners")

c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener")))
networkLayer := c.clusterNetworkLayer

if networkLayer == nil {
networkLayer = cluster.NewTCPLayer(c.clusterListenerAddrs, c.logger.Named("cluster-listener.tcp"))
}

c.clusterListener.Store(cluster.NewListener(networkLayer, c.clusterCipherSuites, c.logger.Named("cluster-listener")))

err := c.getClusterListener().Run(ctx)
if err != nil {
return err
}
if strings.HasSuffix(c.ClusterAddr(), ":0") {
// If we listened on port 0, record the port the OS gave us.
c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0]))
c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addr()))
}
return nil
}
Expand Down Expand Up @@ -355,37 +360,3 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) {
func (c *Core) SetClusterHandler(handler http.Handler) {
c.clusterHandler = handler
}

// getGRPCDialer is used to return a dialer that has the correct TLS
// configuration. Otherwise gRPC tries to be helpful and stomps all over our
// NextProtos.
func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
clusterListener := c.getClusterListener()
if clusterListener == nil {
return nil, errors.New("clustering disabled")
}

tlsConfig, err := clusterListener.TLSConfig(ctx)
if err != nil {
c.logger.Error("failed to get tls configuration", "error", err)
return nil, err
}
if serverName != "" {
tlsConfig.ServerName = serverName
}
if caCert != nil {
pool := x509.NewCertPool()
pool.AddCert(caCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool
}
c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName)

tlsConfig.NextProtos = []string{alpnProto}
dialer := &net.Dialer{
Timeout: timeout,
}
return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
}
}
117 changes: 87 additions & 30 deletions vault/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
Expand All @@ -27,6 +28,8 @@ const (
// Client is used to lookup a client certificate.
type Client interface {
ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error)
ServerName() string
CACert(ctx context.Context) *x509.Certificate
}

// Handler exposes functions for looking up TLS configuration and handing
Expand All @@ -48,6 +51,7 @@ type ClusterHook interface {
StopHandler(alpn string)
TLSConfig(ctx context.Context) (*tls.Config, error)
Addr() net.Addr
GetDialerFunc(ctx context.Context, alpnProto string) func(string, time.Duration) (net.Conn, error)
}

// Listener is the source of truth for cluster handlers and connection
Expand All @@ -60,13 +64,13 @@ type Listener struct {
shutdownWg *sync.WaitGroup
server *http2.Server

listenerAddrs []*net.TCPAddr
cipherSuites []uint16
logger log.Logger
l sync.RWMutex
networkLayer NetworkLayer
cipherSuites []uint16
logger log.Logger
l sync.RWMutex
}

func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener {
func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger) *Listener {
// Create the HTTP/2 server that will be shared by both RPC and regular
// duties. Doing it this way instead of listening via the server and gRPC
// allows us to re-use the same port via ALPN. We can just tell the server
Expand All @@ -84,19 +88,22 @@ func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger)
shutdownWg: &sync.WaitGroup{},
server: h2Server,

listenerAddrs: addrs,
cipherSuites: cipherSuites,
logger: logger,
networkLayer: networkLayer,
cipherSuites: cipherSuites,
logger: logger,
}
}

// TODO: This probably isn't correct
func (cl *Listener) Addr() net.Addr {
return cl.listenerAddrs[0]
addrs := cl.Addrs()
if len(addrs) == 0 {
return nil
}
return addrs[0]
}

func (cl *Listener) Addrs() []*net.TCPAddr {
return cl.listenerAddrs
func (cl *Listener) Addrs() []net.Addr {
return cl.networkLayer.Addrs()
}

// AddClient adds a new client for an ALPN name
Expand Down Expand Up @@ -236,29 +243,15 @@ func (cl *Listener) Run(ctx context.Context) error {
// The server supports all of the possible protos
tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN}

for i, laddr := range cl.listenerAddrs {
for _, ln := range cl.networkLayer.Listeners() {
// closeCh is used to shutdown the spawned goroutines once this
// function returns
closeCh := make(chan struct{})

if cl.logger.IsInfo() {
cl.logger.Info("starting listener", "listener_address", laddr)
}

// Create a TCP listener. We do this separately and specifically
// with TCP so that we can set deadlines.
tcpLn, err := net.ListenTCP("tcp", laddr)
if err != nil {
cl.logger.Error("error starting listener", "error", err)
continue
}
if laddr.String() != tcpLn.Addr().String() {
// If we listened on port 0, record the port the OS gave us.
cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr)
}
localLn := ln

// Wrap the listener with TLS
tlsLn := tls.NewListener(tcpLn, tlsConfig)
tlsLn := tls.NewListener(localLn, tlsConfig)

if cl.logger.IsInfo() {
cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr())
Expand All @@ -281,7 +274,7 @@ func (cl *Listener) Run(ctx context.Context) error {
// Set the deadline for the accept call. If it passes we'll get
// an error, causing us to check the condition at the top
// again.
tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))
localLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline))

// Accept the connection
conn, err := tlsLn.Accept()
Expand Down Expand Up @@ -365,3 +358,67 @@ func (cl *Listener) Stop() {
cl.shutdownWg.Wait()
cl.logger.Info("rpc listeners successfully shut down")
}

// GetDialerFunc returns a function that looks up the TLS information for the
// provided alpn name and calls the network layer's dial function.
func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string, time.Duration) (net.Conn, error) {
return func(addr string, timeout time.Duration) (net.Conn, error) {
tlsConfig, err := cl.TLSConfig(ctx)
if err != nil {
cl.logger.Error("failed to get tls configuration", "error", err)
return nil, err
}

if tlsConfig == nil {
return nil, errors.New("no tls config found")
}

cl.l.RLock()
client, ok := cl.clients[alpn]
cl.l.RUnlock()
if !ok {
return nil, fmt.Errorf("no client configured for alpn: %q", alpn)
}

serverName := client.ServerName()
if serverName != "" {
tlsConfig.ServerName = serverName
}

caCert := client.CACert(ctx)
if caCert != nil {
pool := x509.NewCertPool()
pool.AddCert(caCert)
tlsConfig.RootCAs = pool
tlsConfig.ClientCAs = pool
}

tlsConfig.NextProtos = []string{alpn}
cl.logger.Debug("creating rpc dialer", "alpn", alpn, "host", tlsConfig.ServerName)

return cl.networkLayer.Dial(addr, timeout, tlsConfig)
}
}

// NetworkListener is used by the network layer to define a net.Listener for use
// in the cluster listener.
type NetworkListener interface {
net.Listener

SetDeadline(t time.Time) error
}

// NetworkLayer is the network abstraction used in the cluster listener.
// Abstracting the network layer out allows us to swap the underlying
// implementations for tests.
type NetworkLayer interface {
Addrs() []net.Addr
Listeners() []NetworkListener
Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error)
Close() error
}

// NetworkLayerSet is used for returning a slice of layers to a caller.
type NetworkLayerSet interface {
Layers() []NetworkLayer
}
Loading