From 0b2977e59f95915df590509121be4e4a169a0482 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 16 Jan 2020 18:00:59 -0800 Subject: [PATCH] Create network layer abstraction to allow in-memory cluster traffic --- physical/raft/raft.go | 8 +- physical/raft/streamlayer.go | 61 +++--- vault/cluster.go | 45 +---- vault/cluster/cluster.go | 117 ++++++++--- vault/cluster/inmem_layer.go | 323 ++++++++++++++++++++++++++++++ vault/cluster/inmem_layer_test.go | 240 ++++++++++++++++++++++ vault/cluster/tcp_layer.go | 111 ++++++++++ vault/cluster_test.go | 28 ++- vault/core.go | 12 ++ vault/request_forwarding.go | 26 ++- vault/testing.go | 17 +- 11 files changed, 870 insertions(+), 118 deletions(-) create mode 100644 vault/cluster/inmem_layer.go create mode 100644 vault/cluster/inmem_layer_test.go create mode 100644 vault/cluster/tcp_layer.go diff --git a/physical/raft/raft.go b/physical/raft/raft.go index 97b8566a2450..88c1c7b11f35 100644 --- a/physical/raft/raft.go +++ b/physical/raft/raft.go @@ -503,14 +503,8 @@ func (b *RaftBackend) SetupCluster(ctx context.Context, opts SetupOpts) error { case opts.ClusterListener == nil: return errors.New("no cluster listener provided") default: - // Load the base TLS config from the cluster listener. - baseTLSConfig, err := opts.ClusterListener.TLSConfig(ctx) - if err != nil { - return err - } - // Set the local address and localID in the streaming layer and the raft config. - streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener.Addr(), baseTLSConfig) + streamLayer, err := NewRaftLayer(b.logger.Named("stream"), opts.TLSKeyring, opts.ClusterListener) if err != nil { return err } diff --git a/physical/raft/streamlayer.go b/physical/raft/streamlayer.go index e1fcedbda4b4..fcf0a0be57f8 100644 --- a/physical/raft/streamlayer.go +++ b/physical/raft/streamlayer.go @@ -110,7 +110,7 @@ func GenerateTLSKey(reader io.Reader) (*TLSKey, error) { KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign, SerialNumber: big.NewInt(mathrand.Int63()), NotBefore: time.Now().Add(-30 * time.Second), - // 30 years of single-active uptime ought to be enough for anybody + // 30 years ought to be enough for anybody NotAfter: time.Now().Add(262980 * time.Hour), BasicConstraintsValid: true, IsCA: true, @@ -162,13 +162,14 @@ type raftLayer struct { dialerFunc func(string, time.Duration) (net.Conn, error) // TLS config - keyring *TLSKeyring - baseTLSConfig *tls.Config + keyring *TLSKeyring + clusterListener cluster.ClusterHook } // NewRaftLayer creates a new raftLayer object. It parses the TLS information // from the network config. -func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net.Addr, baseTLSConfig *tls.Config) (*raftLayer, error) { +func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterListener cluster.ClusterHook) (*raftLayer, error) { + clusterAddr := clusterListener.Addr() switch { case clusterAddr == nil: // Clustering disabled on the server, don't try to look for params @@ -176,11 +177,11 @@ func NewRaftLayer(logger log.Logger, raftTLSKeyring *TLSKeyring, clusterAddr net } layer := &raftLayer{ - addr: clusterAddr, - connCh: make(chan net.Conn), - closeCh: make(chan struct{}), - logger: logger, - baseTLSConfig: baseTLSConfig, + addr: clusterAddr, + connCh: make(chan net.Conn), + closeCh: make(chan struct{}), + logger: logger, + clusterListener: clusterListener, } if err := layer.setTLSKeyring(raftTLSKeyring); err != nil { @@ -236,6 +237,24 @@ func (l *raftLayer) setTLSKeyring(keyring *TLSKeyring) error { return nil } +func (l *raftLayer) ServerName() string { + key := l.keyring.GetActive() + if key == nil { + return "" + } + + return key.parsedCert.Subject.CommonName +} + +func (l *raftLayer) CACert(ctx context.Context) *x509.Certificate { + key := l.keyring.GetActive() + if key == nil { + return nil + } + + return key.parsedCert +} + func (l *raftLayer) ClientLookup(ctx context.Context, requestInfo *tls.CertificateRequestInfo) (*tls.Certificate, error) { for _, subj := range requestInfo.AcceptableCAs { for _, key := range l.keyring.Keys { @@ -346,26 +365,6 @@ func (l *raftLayer) Addr() net.Addr { // Dial is used to create a new outgoing connection func (l *raftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { - - tlsConfig := l.baseTLSConfig.Clone() - - key := l.keyring.GetActive() - if key == nil { - return nil, errors.New("no active key") - } - - tlsConfig.NextProtos = []string{consts.RaftStorageALPN} - tlsConfig.ServerName = key.parsedCert.Subject.CommonName - - l.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName) - - pool := x509.NewCertPool() - pool.AddCert(key.parsedCert) - tlsConfig.RootCAs = pool - tlsConfig.ClientCAs = pool - - dialer := &net.Dialer{ - Timeout: timeout, - } - return tls.DialWithDialer(dialer, "tcp", string(address), tlsConfig) + dialFunc := l.clusterListener.GetDialerFunc(context.Background(), consts.RaftStorageALPN) + return dialFunc(string(address), timeout) } diff --git a/vault/cluster.go b/vault/cluster.go index 9ba85f6a091a..e9bc2bb10eda 100644 --- a/vault/cluster.go +++ b/vault/cluster.go @@ -5,7 +5,6 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/json" @@ -302,7 +301,13 @@ func (c *Core) startClusterListener(ctx context.Context) error { c.logger.Debug("starting cluster listeners") - c.clusterListener.Store(cluster.NewListener(c.clusterListenerAddrs, c.clusterCipherSuites, c.logger.Named("cluster-listener"))) + networkLayer := c.clusterNetworkLayer + + if networkLayer == nil { + networkLayer = cluster.NewTCPLayer(c.clusterListenerAddrs, c.logger.Named("cluster-listener.tcp")) + } + + c.clusterListener.Store(cluster.NewListener(networkLayer, c.clusterCipherSuites, c.logger.Named("cluster-listener"))) err := c.getClusterListener().Run(ctx) if err != nil { @@ -310,7 +315,7 @@ func (c *Core) startClusterListener(ctx context.Context) error { } if strings.HasSuffix(c.ClusterAddr(), ":0") { // If we listened on port 0, record the port the OS gave us. - c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addrs()[0])) + c.clusterAddr.Store(fmt.Sprintf("https://%s", c.getClusterListener().Addr())) } return nil } @@ -355,37 +360,3 @@ func (c *Core) SetClusterListenerAddrs(addrs []*net.TCPAddr) { func (c *Core) SetClusterHandler(handler http.Handler) { c.clusterHandler = handler } - -// getGRPCDialer is used to return a dialer that has the correct TLS -// configuration. Otherwise gRPC tries to be helpful and stomps all over our -// NextProtos. -func (c *Core) getGRPCDialer(ctx context.Context, alpnProto, serverName string, caCert *x509.Certificate) func(string, time.Duration) (net.Conn, error) { - return func(addr string, timeout time.Duration) (net.Conn, error) { - clusterListener := c.getClusterListener() - if clusterListener == nil { - return nil, errors.New("clustering disabled") - } - - tlsConfig, err := clusterListener.TLSConfig(ctx) - if err != nil { - c.logger.Error("failed to get tls configuration", "error", err) - return nil, err - } - if serverName != "" { - tlsConfig.ServerName = serverName - } - if caCert != nil { - pool := x509.NewCertPool() - pool.AddCert(caCert) - tlsConfig.RootCAs = pool - tlsConfig.ClientCAs = pool - } - c.logger.Debug("creating rpc dialer", "host", tlsConfig.ServerName) - - tlsConfig.NextProtos = []string{alpnProto} - dialer := &net.Dialer{ - Timeout: timeout, - } - return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) - } -} diff --git a/vault/cluster/cluster.go b/vault/cluster/cluster.go index 2e6de3dfbb05..8b7ed14efc64 100644 --- a/vault/cluster/cluster.go +++ b/vault/cluster/cluster.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" "net" "sync" "sync/atomic" @@ -27,6 +28,8 @@ const ( // Client is used to lookup a client certificate. type Client interface { ClientLookup(context.Context, *tls.CertificateRequestInfo) (*tls.Certificate, error) + ServerName() string + CACert(ctx context.Context) *x509.Certificate } // Handler exposes functions for looking up TLS configuration and handing @@ -48,6 +51,7 @@ type ClusterHook interface { StopHandler(alpn string) TLSConfig(ctx context.Context) (*tls.Config, error) Addr() net.Addr + GetDialerFunc(ctx context.Context, alpnProto string) func(string, time.Duration) (net.Conn, error) } // Listener is the source of truth for cluster handlers and connection @@ -60,13 +64,13 @@ type Listener struct { shutdownWg *sync.WaitGroup server *http2.Server - listenerAddrs []*net.TCPAddr - cipherSuites []uint16 - logger log.Logger - l sync.RWMutex + networkLayer NetworkLayer + cipherSuites []uint16 + logger log.Logger + l sync.RWMutex } -func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) *Listener { +func NewListener(networkLayer NetworkLayer, cipherSuites []uint16, logger log.Logger) *Listener { // Create the HTTP/2 server that will be shared by both RPC and regular // duties. Doing it this way instead of listening via the server and gRPC // allows us to re-use the same port via ALPN. We can just tell the server @@ -84,19 +88,22 @@ func NewListener(addrs []*net.TCPAddr, cipherSuites []uint16, logger log.Logger) shutdownWg: &sync.WaitGroup{}, server: h2Server, - listenerAddrs: addrs, - cipherSuites: cipherSuites, - logger: logger, + networkLayer: networkLayer, + cipherSuites: cipherSuites, + logger: logger, } } -// TODO: This probably isn't correct func (cl *Listener) Addr() net.Addr { - return cl.listenerAddrs[0] + addrs := cl.Addrs() + if len(addrs) == 0 { + return nil + } + return addrs[0] } -func (cl *Listener) Addrs() []*net.TCPAddr { - return cl.listenerAddrs +func (cl *Listener) Addrs() []net.Addr { + return cl.networkLayer.Addrs() } // AddClient adds a new client for an ALPN name @@ -236,29 +243,15 @@ func (cl *Listener) Run(ctx context.Context) error { // The server supports all of the possible protos tlsConfig.NextProtos = []string{"h2", consts.RequestForwardingALPN, consts.PerfStandbyALPN, consts.PerformanceReplicationALPN, consts.DRReplicationALPN} - for i, laddr := range cl.listenerAddrs { + for _, ln := range cl.networkLayer.Listeners() { // closeCh is used to shutdown the spawned goroutines once this // function returns closeCh := make(chan struct{}) - if cl.logger.IsInfo() { - cl.logger.Info("starting listener", "listener_address", laddr) - } - - // Create a TCP listener. We do this separately and specifically - // with TCP so that we can set deadlines. - tcpLn, err := net.ListenTCP("tcp", laddr) - if err != nil { - cl.logger.Error("error starting listener", "error", err) - continue - } - if laddr.String() != tcpLn.Addr().String() { - // If we listened on port 0, record the port the OS gave us. - cl.listenerAddrs[i] = tcpLn.Addr().(*net.TCPAddr) - } + localLn := ln // Wrap the listener with TLS - tlsLn := tls.NewListener(tcpLn, tlsConfig) + tlsLn := tls.NewListener(localLn, tlsConfig) if cl.logger.IsInfo() { cl.logger.Info("serving cluster requests", "cluster_listen_address", tlsLn.Addr()) @@ -281,7 +274,7 @@ func (cl *Listener) Run(ctx context.Context) error { // Set the deadline for the accept call. If it passes we'll get // an error, causing us to check the condition at the top // again. - tcpLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline)) + localLn.SetDeadline(time.Now().Add(ListenerAcceptDeadline)) // Accept the connection conn, err := tlsLn.Accept() @@ -365,3 +358,67 @@ func (cl *Listener) Stop() { cl.shutdownWg.Wait() cl.logger.Info("rpc listeners successfully shut down") } + +// GetDialerFunc returns a function that looks up the TLS information for the +// provided alpn name and calls the network layer's dial function. +func (cl *Listener) GetDialerFunc(ctx context.Context, alpn string) func(string, time.Duration) (net.Conn, error) { + return func(addr string, timeout time.Duration) (net.Conn, error) { + tlsConfig, err := cl.TLSConfig(ctx) + if err != nil { + cl.logger.Error("failed to get tls configuration", "error", err) + return nil, err + } + + if tlsConfig == nil { + return nil, errors.New("no tls config found") + } + + cl.l.RLock() + client, ok := cl.clients[alpn] + cl.l.RUnlock() + if !ok { + return nil, fmt.Errorf("no client configured for alpn: %q", alpn) + } + + serverName := client.ServerName() + if serverName != "" { + tlsConfig.ServerName = serverName + } + + caCert := client.CACert(ctx) + if caCert != nil { + pool := x509.NewCertPool() + pool.AddCert(caCert) + tlsConfig.RootCAs = pool + tlsConfig.ClientCAs = pool + } + + tlsConfig.NextProtos = []string{alpn} + cl.logger.Debug("creating rpc dialer", "alpn", alpn, "host", tlsConfig.ServerName) + + return cl.networkLayer.Dial(addr, timeout, tlsConfig) + } +} + +// NetworkListener is used by the network layer to define a net.Listener for use +// in the cluster listener. +type NetworkListener interface { + net.Listener + + SetDeadline(t time.Time) error +} + +// NetworkLayer is the network abstraction used in the cluster listener. +// Abstracting the network layer out allows us to swap the underlying +// implementations for tests. +type NetworkLayer interface { + Addrs() []net.Addr + Listeners() []NetworkListener + Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) + Close() error +} + +// NetworkLayerSet is used for returning a slice of layers to a caller. +type NetworkLayerSet interface { + Layers() []NetworkLayer +} diff --git a/vault/cluster/inmem_layer.go b/vault/cluster/inmem_layer.go new file mode 100644 index 000000000000..0a00981e2a9e --- /dev/null +++ b/vault/cluster/inmem_layer.go @@ -0,0 +1,323 @@ +package cluster + +import ( + "crypto/tls" + "errors" + "net" + "sync" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/sdk/helper/base62" + "go.uber.org/atomic" +) + +// InmemLayer is an in-memory implementation of NetworkLayer. This is +// primarially useful for tests. +type InmemLayer struct { + listener *inmemListener + addr string + logger log.Logger + + servConns map[string][]net.Conn + clientConns map[string][]net.Conn + + peers map[string]*InmemLayer + l sync.Mutex + + stopped *atomic.Bool + stopCh chan struct{} +} + +// NewInmemLayer returns a new in-memory layer configured to listen on the +// provided address. +func NewInmemLayer(addr string, logger log.Logger) *InmemLayer { + return &InmemLayer{ + addr: addr, + logger: logger, + stopped: atomic.NewBool(false), + stopCh: make(chan struct{}), + peers: make(map[string]*InmemLayer), + servConns: make(map[string][]net.Conn), + clientConns: make(map[string][]net.Conn), + } +} + +// Addrs implements NetworkLayer. +func (l *InmemLayer) Addrs() []net.Addr { + l.l.Lock() + defer l.l.Unlock() + + if l.listener == nil { + return nil + } + + return []net.Addr{l.listener.Addr()} +} + +// Listeners implements NetworkLayer. +func (l *InmemLayer) Listeners() []NetworkListener { + l.l.Lock() + defer l.l.Unlock() + + if l.listener != nil { + return []NetworkListener{l.listener} + } + + l.listener = &inmemListener{ + addr: l.addr, + pendingConns: make(chan net.Conn), + + stopped: atomic.NewBool(false), + stopCh: make(chan struct{}), + } + + return []NetworkListener{l.listener} +} + +// Dial implements NetworkLayer. +func (l *InmemLayer) Dial(addr string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) { + l.l.Lock() + defer l.l.Unlock() + + peer, ok := l.peers[addr] + if !ok { + return nil, errors.New("inmemlayer: no address found") + } + + conn, err := peer.clientConn(l.addr) + if err != nil { + return nil, err + } + + tlsConn := tls.Client(conn, tlsConfig) + + l.clientConns[addr] = append(l.clientConns[addr], tlsConn) + + return tlsConn, nil +} + +// clientConn is executed on a server when a new client connection comes in and +// needs to be Accepted. +func (l *InmemLayer) clientConn(addr string) (net.Conn, error) { + l.l.Lock() + defer l.l.Unlock() + + if l.listener == nil { + return nil, errors.New("inmemlayer: listener not started") + } + + _, ok := l.peers[addr] + if !ok { + return nil, errors.New("inmemlayer: no peer found") + } + + retConn, servConn := net.Pipe() + + l.servConns[addr] = append(l.servConns[addr], servConn) + + select { + case l.listener.pendingConns <- servConn: + case <-time.After(2 * time.Second): + return nil, errors.New("inmemlayer: timeout while accepting connection") + } + + return retConn, nil +} + +// Connect is used to connect this transport to another transport for +// a given peer name. This allows for local routing. +func (l *InmemLayer) Connect(peer string, remote *InmemLayer) { + l.l.Lock() + defer l.l.Unlock() + l.peers[peer] = remote +} + +// Disconnect is used to remove the ability to route to a given peer. +func (l *InmemLayer) Disconnect(peer string) { + l.l.Lock() + defer l.l.Unlock() + delete(l.peers, peer) + + // Remove any open connections + servConns := l.servConns[peer] + for _, c := range servConns { + c.Close() + } + delete(l.servConns, peer) + + clientConns := l.clientConns[peer] + for _, c := range clientConns { + c.Close() + } + delete(l.clientConns, peer) +} + +// DisconnectAll is used to remove all routes to peers. +func (l *InmemLayer) DisconnectAll() { + l.l.Lock() + defer l.l.Unlock() + l.peers = make(map[string]*InmemLayer) + + // Close all connections + for _, peerConns := range l.servConns { + for _, c := range peerConns { + c.Close() + } + } + l.servConns = make(map[string][]net.Conn) + + for _, peerConns := range l.clientConns { + for _, c := range peerConns { + c.Close() + } + } + l.clientConns = make(map[string][]net.Conn) +} + +// Close is used to permanently disable the transport +func (l *InmemLayer) Close() error { + if l.stopped.Swap(true) { + return nil + } + + l.DisconnectAll() + close(l.stopCh) + return nil +} + +// inmemListener implements the NetworkListener interface. +type inmemListener struct { + addr string + pendingConns chan net.Conn + + stopped *atomic.Bool + stopCh chan struct{} + + deadline time.Time +} + +// Accept implements the NetworkListener interface. +func (ln *inmemListener) Accept() (net.Conn, error) { + deadline := ln.deadline + if !deadline.IsZero() { + select { + case conn := <-ln.pendingConns: + return conn, nil + case <-time.After(time.Until(deadline)): + return nil, deadlineError("deadline") + case <-ln.stopCh: + return nil, errors.New("listener shut down") + } + } + + select { + case conn := <-ln.pendingConns: + return conn, nil + case <-ln.stopCh: + return nil, errors.New("listener shut down") + } +} + +// Close implements the NetworkListener interface. +func (ln *inmemListener) Close() error { + if ln.stopped.Swap(true) { + return nil + } + + close(ln.stopCh) + return nil +} + +// Addr implements the NetworkListener interface. +func (ln *inmemListener) Addr() net.Addr { + return inmemAddr{addr: ln.addr} +} + +// SetDeadline implements the NetworkListener interface. +func (ln *inmemListener) SetDeadline(d time.Time) error { + ln.deadline = d + return nil +} + +type inmemAddr struct { + addr string +} + +func (a inmemAddr) Network() string { + return "inmem" +} +func (a inmemAddr) String() string { + return a.addr +} + +type deadlineError string + +func (d deadlineError) Error() string { return string(d) } +func (d deadlineError) Timeout() bool { return true } +func (d deadlineError) Temporary() bool { return true } + +// InmemLayerCluster composes a set of layers and handles connecting them all +// together. It also satisfies the NetworkLayerSet interface. +type InmemLayerCluster struct { + layers []*InmemLayer +} + +// NewInmemLayerCluster returns a new in-memory layer set that builds n nodes +// and connects them all together. +func NewInmemLayerCluster(nodes int, logger log.Logger) (*InmemLayerCluster, error) { + clusterID, err := base62.Random(4) + if err != nil { + return nil, err + } + + clusterName := "cluster_" + clusterID + + var layers []*InmemLayer + for i := 0; i < nodes; i++ { + nodeID, err := base62.Random(4) + if err != nil { + return nil, err + } + + nodeName := clusterName + "_node_" + nodeID + + layers = append(layers, NewInmemLayer(nodeName, logger)) + } + + // Connect all the peers together + for _, node := range layers { + for _, peer := range layers { + // Don't connect to itself + if node == peer { + continue + } + + node.Connect(peer.addr, peer) + peer.Connect(node.addr, node) + } + } + + return &InmemLayerCluster{layers: layers}, nil +} + +// ConnectCluster connects this cluster with the provided remote cluster, +// connecting all nodes to each other. +func (ic *InmemLayerCluster) ConnectCluster(remote *InmemLayerCluster) { + for _, node := range ic.layers { + for _, peer := range remote.layers { + node.Connect(peer.addr, peer) + peer.Connect(node.addr, node) + } + } +} + +// Layers implements the NetworkLayerSet interface. +func (ic *InmemLayerCluster) Layers() []NetworkLayer { + ret := make([]NetworkLayer, len(ic.layers)) + for i, l := range ic.layers { + ret[i] = l + } + + return ret +} diff --git a/vault/cluster/inmem_layer_test.go b/vault/cluster/inmem_layer_test.go new file mode 100644 index 000000000000..35d7e44eb7bb --- /dev/null +++ b/vault/cluster/inmem_layer_test.go @@ -0,0 +1,240 @@ +package cluster + +import ( + "sync" + "testing" + "time" + + "go.uber.org/atomic" +) + +func TestInmemCluster_Connect(t *testing.T) { + cluster, err := NewInmemLayerCluster(3, nil) + if err != nil { + t.Fatal(err) + } + + server := cluster.layers[0] + + listener := server.Listeners()[0] + var accepted int + stopCh := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stopCh: + return + default: + } + + listener.SetDeadline(time.Now().Add(5 * time.Second)) + + _, err := listener.Accept() + if err != nil { + return + } + + accepted++ + + } + }() + + // Make sure two nodes can connect in + conn, err := cluster.layers[1].Dial(server.addr, 0, nil) + if err != nil { + t.Fatal(err) + } + + if conn == nil { + t.Fatal("nil conn") + } + + conn, err = cluster.layers[2].Dial(server.addr, 0, nil) + if err != nil { + t.Fatal(err) + } + + if conn == nil { + t.Fatal("nil conn") + } + + close(stopCh) + wg.Wait() + + if accepted != 2 { + t.Fatalf("expected 2 connections to be accepted, got %d", accepted) + } +} + +func TestInmemCluster_Disconnect(t *testing.T) { + cluster, err := NewInmemLayerCluster(3, nil) + if err != nil { + t.Fatal(err) + } + + server := cluster.layers[0] + server.Disconnect(cluster.layers[1].addr) + + listener := server.Listeners()[0] + var accepted int + stopCh := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stopCh: + return + default: + } + + listener.SetDeadline(time.Now().Add(5 * time.Second)) + + _, err := listener.Accept() + if err != nil { + return + } + + accepted++ + + } + }() + + // Make sure node1 cannot connect in + conn, err := cluster.layers[1].Dial(server.addr, 0, nil) + if err == nil { + t.Fatal("expected error") + } + + if conn != nil { + t.Fatal("expected nil conn") + } + + // Node2 should be able to connect + conn, err = cluster.layers[2].Dial(server.addr, 0, nil) + if err != nil { + t.Fatal(err) + } + + if conn == nil { + t.Fatal("nil conn") + } + + close(stopCh) + wg.Wait() + + if accepted != 1 { + t.Fatalf("expected 1 connections to be accepted, got %d", accepted) + } +} + +func TestInmemCluster_DisconnectAll(t *testing.T) { + cluster, err := NewInmemLayerCluster(3, nil) + if err != nil { + t.Fatal(err) + } + + server := cluster.layers[0] + server.DisconnectAll() + + // Make sure nodes cannot connect in + conn, err := cluster.layers[1].Dial(server.addr, 0, nil) + if err == nil { + t.Fatal("expected error") + } + + if conn != nil { + t.Fatal("expected nil conn") + } + + conn, err = cluster.layers[2].Dial(server.addr, 0, nil) + if err == nil { + t.Fatal("expected error") + } + + if conn != nil { + t.Fatal("expected nil conn") + } +} + +func TestInmemCluster_ConnectCluster(t *testing.T) { + cluster, err := NewInmemLayerCluster(3, nil) + if err != nil { + t.Fatal(err) + } + cluster2, err := NewInmemLayerCluster(3, nil) + if err != nil { + t.Fatal(err) + } + + cluster.ConnectCluster(cluster2) + + var accepted atomic.Int32 + stopCh := make(chan struct{}) + var wg sync.WaitGroup + acceptConns := func(listener NetworkListener) { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stopCh: + return + default: + } + + listener.SetDeadline(time.Now().Add(5 * time.Second)) + + _, err := listener.Accept() + if err != nil { + return + } + + accepted.Add(1) + + } + }() + } + + // Start a listener on each node. + for _, node := range cluster.layers { + acceptConns(node.Listeners()[0]) + } + for _, node := range cluster2.layers { + acceptConns(node.Listeners()[0]) + } + + // Make sure each node can connect to each other + for _, node1 := range cluster.layers { + for _, node2 := range cluster2.layers { + conn, err := node1.Dial(node2.addr, 0, nil) + if err != nil { + t.Fatal(err) + } + + if conn == nil { + t.Fatal("nil conn") + } + + conn, err = node2.Dial(node1.addr, 0, nil) + if err != nil { + t.Fatal(err) + } + + if conn == nil { + t.Fatal("nil conn") + } + } + } + + close(stopCh) + wg.Wait() + + if accepted.Load() != 18 { + t.Fatalf("expected 18 connections to be accepted, got %d", accepted) + } +} diff --git a/vault/cluster/tcp_layer.go b/vault/cluster/tcp_layer.go new file mode 100644 index 000000000000..a0092d8b81fc --- /dev/null +++ b/vault/cluster/tcp_layer.go @@ -0,0 +1,111 @@ +package cluster + +import ( + "crypto/tls" + "net" + "sync" + "time" + + log "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-multierror" + "go.uber.org/atomic" +) + +// TCPLayer implements the NetworkLayer interface and uses TCP as the underlying +// network. +type TCPLayer struct { + listeners []NetworkListener + addrs []*net.TCPAddr + logger log.Logger + + l sync.Mutex + stopped *atomic.Bool +} + +// NewTCPLayer returns a TCPLayer. +func NewTCPLayer(addrs []*net.TCPAddr, logger log.Logger) *TCPLayer { + return &TCPLayer{ + addrs: addrs, + logger: logger, + stopped: atomic.NewBool(false), + } +} + +// Addrs implements NetworkLayer. +func (l *TCPLayer) Addrs() []net.Addr { + l.l.Lock() + defer l.l.Unlock() + + if len(l.addrs) == 0 { + return nil + } + + ret := make([]net.Addr, len(l.addrs)) + for i, a := range l.addrs { + ret[i] = a + } + + return ret +} + +// Listeners implements NetworkLayer. It starts a new TCP listener for each +// configured address. +func (l *TCPLayer) Listeners() []NetworkListener { + l.l.Lock() + defer l.l.Unlock() + + if l.listeners != nil { + return l.listeners + } + + listeners := make([]NetworkListener, len(l.addrs)) + for i, laddr := range l.addrs { + if l.logger.IsInfo() { + l.logger.Info("starting listener", "listener_address", laddr) + } + + tcpLn, err := net.ListenTCP("tcp", laddr) + if err != nil { + l.logger.Error("error starting listener", "error", err) + continue + } + if laddr.String() != tcpLn.Addr().String() { + // If we listened on port 0, record the port the OS gave us. + l.addrs[i] = tcpLn.Addr().(*net.TCPAddr) + } + + listeners[i] = tcpLn + } + + l.listeners = listeners + + return listeners +} + +// Dial implements the NetworkLayer interface. +func (l *TCPLayer) Dial(address string, timeout time.Duration, tlsConfig *tls.Config) (*tls.Conn, error) { + dialer := &net.Dialer{ + Timeout: timeout, + } + return tls.DialWithDialer(dialer, "tcp", address, tlsConfig) +} + +// Close implements the NetworkLayer interface. +func (l *TCPLayer) Close() error { + if l.stopped.Swap(true) { + return nil + } + l.l.Lock() + defer l.l.Unlock() + + var retErr *multierror.Error + for _, ln := range l.listeners { + if err := ln.Close(); err != nil { + retErr = multierror.Append(retErr, err) + } + } + + l.listeners = nil + + return retErr.ErrorOrNil() +} diff --git a/vault/cluster_test.go b/vault/cluster_test.go index d148e5f786ff..f6f9693f3ae6 100644 --- a/vault/cluster_test.go +++ b/vault/cluster_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" "net/http" "testing" "time" @@ -15,6 +14,7 @@ import ( "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/physical" "github.com/hashicorp/vault/sdk/physical/inmem" + "github.com/hashicorp/vault/vault/cluster" ) var ( @@ -100,13 +100,13 @@ func TestCluster_ListenForRequests(t *testing.T) { // Wait for core to become active TestWaitActive(t, cores[0].Core) - cores[0].getClusterListener().AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core}) + clusterListener := cores[0].getClusterListener() + clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{cores[0].Core}) addrs := cores[0].getClusterListener().Addrs() // Use this to have a valid config after sealing since ClusterTLSConfig returns nil checkListenersFunc := func(expectFail bool) { - parsedCert := cores[0].localClusterParsedCert.Load().(*x509.Certificate) - dialer := cores[0].getGRPCDialer(context.Background(), consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert) + dialer := clusterListener.GetDialerFunc(context.Background(), consts.RequestForwardingALPN) for i := range cores[0].Listeners { clnAddr := addrs[i] @@ -172,11 +172,25 @@ func TestCluster_ForwardRequests(t *testing.T) { // Make this nicer for tests manualStepDownSleepPeriod = 5 * time.Second - testCluster_ForwardRequestsCommon(t) + t.Run("tcpLayer", func(t *testing.T) { + testCluster_ForwardRequestsCommon(t, nil) + }) + + t.Run("inmemLayer", func(t *testing.T) { + // Run again with in-memory network + inmemCluster, err := cluster.NewInmemLayerCluster(3, nil) + if err != nil { + t.Fatal(err) + } + + testCluster_ForwardRequestsCommon(t, &TestClusterOptions{ + ClusterLayers: inmemCluster, + }) + }) } -func testCluster_ForwardRequestsCommon(t *testing.T) { - cluster := NewTestCluster(t, nil, nil) +func testCluster_ForwardRequestsCommon(t *testing.T, clusterOpts *TestClusterOptions) { + cluster := NewTestCluster(t, nil, clusterOpts) cores := cluster.Cores cores[0].Handler.(*http.ServeMux).HandleFunc("/core1", func(w http.ResponseWriter, req *http.Request) { w.Header().Add("Content-Type", "application/json") diff --git a/vault/core.go b/vault/core.go index bf01c3ff018c..f8e503732872 100644 --- a/vault/core.go +++ b/vault/core.go @@ -493,6 +493,14 @@ type Core struct { secureRandomReader io.Reader recoveryMode bool + + clusterNetworkLayer cluster.NetworkLayer + + // PR1103disabled is used to test upgrade workflows: when set to true, + // the correct behaviour for namespaced cubbyholes is disabled, so we + // can test an upgrade to a version that includes the fixes from + // https://github.com/hashicorp/vault-enterprise/pull/1103 + PR1103disabled bool } // CoreConfig is used to parameterize a core @@ -576,6 +584,8 @@ type CoreConfig struct { CounterSyncInterval time.Duration RecoveryMode bool + + ClusterNetworkLayer cluster.NetworkLayer } func (c *CoreConfig) Clone() *CoreConfig { @@ -611,6 +621,7 @@ func (c *CoreConfig) Clone() *CoreConfig { DisableIndexing: c.DisableIndexing, AllLoggers: c.AllLoggers, CounterSyncInterval: c.CounterSyncInterval, + ClusterNetworkLayer: c.ClusterNetworkLayer, } } @@ -706,6 +717,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { maxLeaseTTL: conf.MaxLeaseTTL, cachingDisabled: conf.DisableCache, clusterName: conf.ClusterName, + clusterNetworkLayer: conf.ClusterNetworkLayer, clusterPeerClusterAddrsCache: cache.New(3*cluster.HeartbeatInterval, time.Second), enableMlock: !conf.DisableMlock, rawEnabled: conf.EnableRaw, diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index b97263e6ded4..e7961ee938cc 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -99,6 +99,19 @@ func (c *requestForwardingClusterClient) ClientLookup(ctx context.Context, reque return nil, nil } +func (c *requestForwardingClusterClient) ServerName() string { + parsedCert := c.core.localClusterParsedCert.Load().(*x509.Certificate) + if parsedCert == nil { + return "" + } + + return parsedCert.Subject.CommonName +} + +func (c *requestForwardingClusterClient) CACert(ctx context.Context) *x509.Certificate { + return c.core.localClusterParsedCert.Load().(*x509.Certificate) +} + // ServerLookup satisfies the ClusterHandler interface and returns the server's // tls certs. func (rf *requestForwardingHandler) ServerLookup(ctx context.Context, clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -246,19 +259,22 @@ func (c *Core) refreshRequestForwardingConnection(ctx context.Context, clusterAd } clusterListener := c.getClusterListener() - if clusterListener != nil { - clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{ - core: c, - }) + if clusterListener == nil { + c.logger.Error("no cluster listener configured") + return nil } + clusterListener.AddClient(consts.RequestForwardingALPN, &requestForwardingClusterClient{ + core: c, + }) + // Set up grpc forwarding handling // It's not really insecure, but we have to dial manually to get the // ALPN header right. It's just "insecure" because GRPC isn't managing // the TLS state. dctx, cancelFunc := context.WithCancel(ctx) c.rpcClientConn, err = grpc.DialContext(dctx, clusterURL.Host, - grpc.WithDialer(c.getGRPCDialer(ctx, consts.RequestForwardingALPN, parsedCert.Subject.CommonName, parsedCert)), + grpc.WithDialer(clusterListener.GetDialerFunc(ctx, consts.RequestForwardingALPN)), grpc.WithInsecure(), // it's not, we handle it in the dialer grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 2 * cluster.HeartbeatInterval, diff --git a/vault/testing.go b/vault/testing.go index d485399942b2..2686b845a2ab 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -31,6 +31,7 @@ import ( hclog "github.com/hashicorp/go-hclog" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/metricsutil" + "github.com/hashicorp/vault/vault/cluster" "github.com/hashicorp/vault/vault/seal" "github.com/mitchellh/copystructure" @@ -1057,7 +1058,11 @@ type TestClusterOptions struct { FirstCoreNumber int RequireClientAuth bool // SetupFunc is called after the cluster is started. - SetupFunc func(t testing.T, c *TestCluster) + SetupFunc func(t testing.T, c *TestCluster) + PR1103Disabled bool + + // ClusterLayers are used to override the default cluster connection layer + ClusterLayers cluster.NetworkLayerSet } var DefaultNumCores = 3 @@ -1093,6 +1098,11 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te numCores = opts.NumCores } + var disablePR1103 bool + if opts != nil && opts.PR1103Disabled { + disablePR1103 = true + } + var firstCoreNumber int if opts != nil { firstCoreNumber = opts.FirstCoreNumber @@ -1486,6 +1496,10 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te } } + if opts != nil && opts.ClusterLayers != nil { + localConfig.ClusterNetworkLayer = opts.ClusterLayers.Layers()[i] + } + switch { case localConfig.LicensingConfig != nil: if pubKey != nil { @@ -1506,6 +1520,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te t.Fatalf("err: %v", err) } c.coreNumber = firstCoreNumber + i + c.PR1103disabled = disablePR1103 cores = append(cores, c) coreConfigs = append(coreConfigs, &localConfig) if opts != nil && opts.HandlerFunc != nil {