rpc: oss changes for network area connection pooling (#7735)
hanshasselberg authored Apr 30, 2020
1 parent 27eb12e commit 51549bd
Showing 18 changed files with 209 additions and 86 deletions.
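
At a high level, this change removes the per-server useTLS flag from the connection pool API: the pool now decides TLS per datacenter via its TLSConfigurator, and the router and server manager are given their own server name so they can skip pinging themselves. A paraphrased before/after of the two pool signatures touched throughout the diff (see agent/pool/pool.go below); this is a summary sketch, not a complete listing of the change:

// Before: callers passed a per-server useTLS flag.
//   func (p *ConnPool) RPC(dc string, nodeName string, addr net.Addr, version int, method string, useTLS bool, args interface{}, reply interface{}) error
//   func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error)
//
// After: TLS is derived inside the pool from the target datacenter.
//   func (p *ConnPool) RPC(dc string, nodeName string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error
//   func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error)
//   // inside rpc(): useTLS := p.TLSConfigurator.UseTLS(dc)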
2 changes: 1 addition & 1 deletion agent/consul/auto_encrypt.go
@@ -109,7 +109,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin
for _, ip := range ips {
addr := net.TCPAddr{IP: ip, Port: port}

if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", true, &args, &reply); err == nil {
if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", &args, &reply); err == nil {
return &reply, pkPEM, nil
} else {
c.logger.Warn("AutoEncrypt failed", "error", err)
4 changes: 2 additions & 2 deletions agent/consul/client.go
@@ -186,7 +186,7 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat
}

// Start maintenance task for servers
c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool)
c.routers = router.New(c.logger, c.shutdownCh, c.serf, c.connPool, "")
go c.routers.Start()

// Start LAN event handlers after the router is complete since the event
@@ -308,7 +308,7 @@ TRY:
}

// Make the request.
rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, server.UseTLS, args, reply)
rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, args, reply)
if rpcErr == nil {
return nil
}
2 changes: 1 addition & 1 deletion agent/consul/client_test.go
@@ -418,7 +418,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) {
for range servers {
time.Sleep(200 * time.Millisecond)
s := c.routers.FindServer()
ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version, s.UseTLS)
ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version)
if !ok {
t.Errorf("Unable to ping server %v: %s", s.String(), err)
}
41 changes: 38 additions & 3 deletions agent/consul/rpc.go
@@ -307,7 +307,42 @@ func (s *Server) handleMultiplexV2(conn net.Conn) {
}
return
}
go s.handleConsulConn(sub)

// Originally only the RPC protocol was multiplexed over yamux. To also
// multiplex network area connections, this workaround peeks at the first
// byte of the stream: if it is RPCGossip, the connection is handed to
// enterprise code; otherwise it is handled by the regular RPC handler as
// before.
// This would break if a normal RPC could start with RPCGossip(6). In
// messagepack a 6 encodes a positive fixint
// (https://github.com/msgpack/msgpack/blob/master/spec.md), and none of
// our RPCs start with that byte; the first thing sent is usually the
// datacenter string.
peeked, first, err := pool.PeekFirstByte(sub)
if err != nil {
s.rpcLogger().Error("Problem peeking connection", "conn", logConn(sub), "err", err)
sub.Close()
return
}
sub = peeked
switch first {
case pool.RPCGossip:
buf := make([]byte, 1)
sub.Read(buf)
go func() {
if !s.handleEnterpriseRPCConn(pool.RPCGossip, sub, false) {
s.rpcLogger().Error("unrecognized RPC byte",
"byte", pool.RPCGossip,
"conn", logConn(conn),
)
sub.Close()
}
}()
default:
go s.handleConsulConn(sub)
}
}
}

@@ -517,7 +552,7 @@ CHECK_LEADER:
rpcErr := structs.ErrNoLeader
if leader != nil {
rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr,
leader.Version, method, leader.UseTLS, args, reply)
leader.Version, method, args, reply)
if rpcErr != nil && canRetry(info, rpcErr) {
goto RETRY
}
@@ -582,7 +617,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{

metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1,
[]metrics.Label{{Name: "datacenter", Value: dc}})
if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil {
if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, args, reply); err != nil {
manager.NotifyFailedServer(server)
s.rpcLogger().Error("RPC failed to server in DC",
"server", server.Addr,
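
As a standalone illustration of the peek-and-dispatch pattern described in the comment above, here is a minimal sketch. The package name, the dispatch helper, the bufferedConn type, and the handler signatures are invented for this example; Consul's actual implementation uses pool.PeekFirstByte (shown later in this diff) and an enterprise handler that reports unrecognized bytes.

package peekdemo

import (
	"bufio"
	"net"
)

// rpcGossip mirrors pool.RPCGossip (6) purely for illustration.
const rpcGossip byte = 6

// bufferedConn replays bytes held in the bufio.Reader before falling
// through to the underlying connection for everything else.
type bufferedConn struct {
	r *bufio.Reader
	net.Conn
}

func (b *bufferedConn) Read(p []byte) (int, error) { return b.r.Read(p) }

// dispatch peeks the first byte of a multiplexed stream without consuming
// it, then routes the stream: gossip streams go to gossipHandler (which
// expects the type byte to already be consumed), everything else goes to
// the plain RPC handler.
func dispatch(conn net.Conn, gossipHandler, rpcHandler func(net.Conn)) error {
	br := bufio.NewReader(conn)
	first, err := br.Peek(1)
	if err != nil {
		conn.Close()
		return err
	}
	wrapped := &bufferedConn{r: br, Conn: conn}
	if first[0] == rpcGossip {
		// Consume the type byte before handing off, as the diff above
		// does with sub.Read(buf).
		if _, err := br.Discard(1); err != nil {
			conn.Close()
			return err
		}
		go gossipHandler(wrapped)
		return nil
	}
	go rpcHandler(wrapped)
	return nil
}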
20 changes: 7 additions & 13 deletions agent/consul/server.go
@@ -391,7 +391,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token
loggers: loggers,
leaveCh: make(chan struct{}),
reconcileCh: make(chan serf.Member, reconcileChSize),
router: router.NewRouter(serverLogger, config.Datacenter),
router: router.NewRouter(serverLogger, config.Datacenter, fmt.Sprintf("%s.%s", config.NodeName, config.Datacenter)),
rpcServer: rpc.NewServer(),
insecureRPCServer: rpc.NewServer(),
tlsConfigurator: tlsConfigurator,
@@ -551,7 +551,7 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token

// Add a "static route" to the WAN Serf and hook it up to Serf events.
if s.serfWAN != nil {
if err := s.router.AddArea(types.AreaWAN, s.serfWAN, s.connPool, s.config.VerifyOutgoing); err != nil {
if err := s.router.AddArea(types.AreaWAN, s.serfWAN, s.connPool); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to add WAN serf route: %v", err)
}
@@ -839,23 +839,16 @@ func (s *Server) setupRPC() error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", s.config.RPCAdvertise)
}

// TODO (hans) switch NewRaftLayer to tlsConfigurator

// Provide a DC specific wrapper. Raft replication is only
// ever done in the same datacenter, so we can provide it as a constant.
wrapper := tlsutil.SpecificDC(s.config.Datacenter, s.tlsConfigurator.OutgoingRPCWrapper())

// Define a callback for determining whether to wrap a connection with TLS
tlsFunc := func(address raft.ServerAddress) bool {
if s.config.VerifyOutgoing {
return true
}

server := s.serverLookup.Server(address)

if server == nil {
return false
}

return server.UseTLS
// raft only talks to its own datacenter
return s.tlsConfigurator.UseTLS(s.config.Datacenter)
}
s.raftLayer = NewRaftLayer(s.config.RPCSrcAddr, s.config.RPCAdvertise, wrapper, tlsFunc)
return nil
@@ -1361,6 +1354,7 @@ func (s *Server) ReloadConfig(config *Config) error {
// this will error if we lose leadership while bootstrapping here.
return s.bootstrapConfigEntries(config.ConfigEntryBootstrap)
}

return nil
}

2 changes: 1 addition & 1 deletion agent/consul/server_serf.go
@@ -364,7 +364,7 @@ func (s *Server) maybeBootstrap() {
// Retry with exponential backoff to get peer status from this server
for attempt := uint(0); attempt < maxPeerRetries; attempt++ {
if err := s.connPool.RPC(s.config.Datacenter, server.ShortName, server.Addr, server.Version,
"Status.Peers", server.UseTLS, &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil {
"Status.Peers", &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil {
nextRetry := time.Duration((1 << attempt) * peerRetryBase)
s.logger.Error("Failed to confirm peer status for server (will retry).",
"server", server.Name,
3 changes: 1 addition & 2 deletions agent/consul/server_test.go
@@ -1213,7 +1213,7 @@ func testVerifyRPC(s1, s2 *Server, t *testing.T) (bool, error) {
if leader == nil {
t.Fatal("no leader")
}
return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version, leader.UseTLS)
return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version)
}

func TestServer_TLSToNoTLS(t *testing.T) {
@@ -1277,7 +1277,6 @@ func TestServer_TLSToFullVerify(t *testing.T) {
c.CAFile = "../../test/client_certs/rootca.crt"
c.CertFile = "../../test/client_certs/server.crt"
c.KeyFile = "../../test/client_certs/server.key"
c.VerifyIncoming = true
c.VerifyOutgoing = true
})
defer os.RemoveAll(dir1)
2 changes: 1 addition & 1 deletion agent/consul/stats_fetcher.go
@@ -43,7 +43,7 @@ func NewStatsFetcher(logger hclog.Logger, pool *pool.ConnPool, datacenter string
func (f *StatsFetcher) fetch(server *metadata.Server, replyCh chan *autopilot.ServerStats) {
var args struct{}
var reply autopilot.ServerStats
err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", server.UseTLS, &args, &reply)
err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", &args, &reply)
if err != nil {
f.logger.Warn("error getting server health from server",
"server", server.Name,
30 changes: 30 additions & 0 deletions agent/pool/peek.go
@@ -2,6 +2,7 @@ package pool

import (
"bufio"
"fmt"
"net"
)

@@ -47,3 +48,32 @@ func PeekForTLS(conn net.Conn) (net.Conn, bool, error) {
Conn: conn,
}, isTLS, nil
}

// PeekFirstByte will read the first byte on the conn.
//
// This function does not close the conn on an error.
//
// The returned conn has the initial read buffered internally for the purposes
// of not consuming the first byte. After that buffer is drained the conn is a
// pass through to the original conn.
func PeekFirstByte(conn net.Conn) (net.Conn, byte, error) {
br := bufio.NewReader(conn)

// Grab enough to read the first byte. Then drain the buffer so future
// reads can be direct.
peeked, err := br.Peek(1)
if err != nil {
return nil, 0, err
} else if len(peeked) == 0 {
return conn, 0, fmt.Errorf("nothing to read")
}
peeked, err = br.Peek(br.Buffered())
if err != nil {
return nil, 0, err
}

return &peekedConn{
Peeked: peeked,
Conn: conn,
}, peeked[0], nil
}
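
A hedged usage sketch of the new helper over an in-memory net.Pipe; the import path github.com/hashicorp/consul/agent/pool is assumed, and the payload bytes and function name are illustrative only.

package peekdemo

import (
	"fmt"
	"net"

	"github.com/hashicorp/consul/agent/pool"
)

func examplePeekFirstByte() error {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// net.Pipe writes block until read, so write from a goroutine.
	go client.Write([]byte{0x06, 0xAA, 0xBB})

	conn, first, err := pool.PeekFirstByte(server)
	if err != nil {
		return err
	}
	fmt.Printf("peeked first byte: %#x\n", first) // 0x6

	// The returned conn still yields the peeked bytes on the next read.
	buf := make([]byte, 3)
	n, err := conn.Read(buf)
	if err != nil {
		return err
	}
	fmt.Printf("read %d bytes starting with %#x\n", n, buf[0])
	return nil
}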
12 changes: 6 additions & 6 deletions agent/pool/pool.go
@@ -389,7 +389,7 @@ func DialTimeoutWithRPCTypeDirectly(
}

// Check if TLS is enabled
if (useTLS) && wrapper != nil {
if useTLS && wrapper != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil {
conn.Close()
@@ -600,7 +600,6 @@ func (p *ConnPool) RPC(
addr net.Addr,
version int,
method string,
useTLS bool,
args interface{},
reply interface{},
) error {
@@ -611,7 +610,7 @@
if method == "AutoEncrypt.Sign" {
return p.rpcInsecure(dc, nodeName, addr, method, args, reply)
} else {
return p.rpc(dc, nodeName, addr, version, method, useTLS, args, reply)
return p.rpc(dc, nodeName, addr, version, method, args, reply)
}
}

@@ -637,10 +636,11 @@ func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method
return nil
}

func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, useTLS bool, args interface{}, reply interface{}) error {
func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
p.once.Do(p.init)

// Get a usable client
useTLS := p.TLSConfigurator.UseTLS(dc)
conn, sc, err := p.getClient(dc, nodeName, addr, version, useTLS)
if err != nil {
return fmt.Errorf("rpc error getting client: %v", err)
@@ -671,9 +671,9 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, m

// Ping sends a Status.Ping message to the specified server and
// returns true if healthy, false if an error occurred
func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error) {
func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) {
var out struct{}
err := p.RPC(dc, nodeName, addr, version, "Status.Ping", useTLS, struct{}{}, &out)
err := p.RPC(dc, nodeName, addr, version, "Status.Ping", struct{}{}, &out)
return err == nil, err
}

16 changes: 13 additions & 3 deletions agent/router/manager.go
@@ -61,7 +61,7 @@ type ManagerSerfCluster interface {
// Pinger is an interface wrapping client.ConnPool to prevent a cyclic import
// dependency.
type Pinger interface {
Ping(dc, nodeName string, addr net.Addr, version int, useTLS bool) (bool, error)
Ping(dc, nodeName string, addr net.Addr, version int) (bool, error)
}

// serverList is a local copy of the struct used to maintain the list of
@@ -98,6 +98,10 @@ type Manager struct {
// client.ConnPool.
connPoolPinger Pinger

// serverName has the name of the manager's own server. This is used to
// short-circuit pinging itself.
serverName string

// notifyFailedBarrier acts as a barrier to prevent queuing behind
// serverListLog and acts as a TryLock().
notifyFailedBarrier int32
@@ -256,7 +260,7 @@ func (m *Manager) saveServerList(l serverList) {
}

// New is the only way to safely create a new Manager struct.
func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger) (m *Manager) {
func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfCluster, connPoolPinger Pinger, serverName string) (m *Manager) {
if logger == nil {
logger = hclog.New(&hclog.LoggerOptions{})
}
@@ -267,6 +271,7 @@ func New(logger hclog.Logger, shutdownCh chan struct{}, clusterInfo ManagerSerfC
m.connPoolPinger = connPoolPinger // can't pass *consul.ConnPool: import cycle
m.rebalanceTimer = time.NewTimer(clientRPCMinReuseDuration)
m.shutdownCh = shutdownCh
m.serverName = serverName
atomic.StoreInt32(&m.offline, 1)

l := serverList{}
@@ -340,7 +345,12 @@ func (m *Manager) RebalanceServers() {
// while Serf detects the node has failed.
srv := l.servers[0]

ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version, srv.UseTLS)
// Check to see if the manager is trying to ping itself and
// skip to the next server if that is the case.
if m.serverName != "" && srv.Name == m.serverName {
continue
}
ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version)
if ok {
foundHealthyServer = true
break
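
To see the effect of the new serverName argument, a small sketch of a counting Pinger in the spirit of fauxConnPool in the next file; the countingPinger type, the example server name, and the construction call in the comment are assumptions for illustration, not code from this commit.

package routerdemo

import "net"

// countingPinger records how many times the Manager actually pinged a
// server, so the self-ping short-circuit is observable.
type countingPinger struct {
	calls int
}

func (c *countingPinger) Ping(dc, nodeName string, addr net.Addr, version int) (bool, error) {
	c.calls++
	return true, nil
}

// With the manager constructed as
//
//	m := router.New(logger, shutdownCh, serfCluster, pinger, "node-a.dc1")
//
// a rebalance that lands on the server named "node-a.dc1" skips the ping
// entirely, while any other server is pinged as before.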
10 changes: 5 additions & 5 deletions agent/router/manager_internal_test.go
@@ -33,7 +33,7 @@ type fauxConnPool struct {
failPct float64
}

func (cp *fauxConnPool) Ping(string, string, net.Addr, int, bool) (bool, error) {
func (cp *fauxConnPool) Ping(string, string, net.Addr, int) (bool, error) {
var success bool
successProb := rand.Float64()
if successProb > cp.failPct {
@@ -53,14 +53,14 @@ func (s *fauxSerf) NumNodes() int {
func testManager() (m *Manager) {
logger := GetBufferedLogger()
shutdownCh := make(chan struct{})
m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{})
m = New(logger, shutdownCh, &fauxSerf{numNodes: 16384}, &fauxConnPool{}, "")
return m
}

func testManagerFailProb(failPct float64) (m *Manager) {
logger := GetBufferedLogger()
shutdownCh := make(chan struct{})
m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct})
m = New(logger, shutdownCh, &fauxSerf{}, &fauxConnPool{failPct: failPct}, "")
return m
}

@@ -179,7 +179,7 @@ func test_reconcileServerList(maxServers int) (bool, error) {
// failPct of the servers for the reconcile. This
// allows for the selected server to no longer be
// healthy for the reconcile below.
if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version, node.UseTLS); ok {
if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version); ok {
// Will still be present
healthyServers = append(healthyServers, node)
} else {
@@ -299,7 +299,7 @@ func TestManagerInternal_refreshServerRebalanceTimer(t *testing.T) {
shutdownCh := make(chan struct{})

for _, s := range clusters {
m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{})
m := New(logger, shutdownCh, &fauxSerf{numNodes: s.numNodes}, &fauxConnPool{}, "")
for i := 0; i < s.numServers; i++ {
nodeName := fmt.Sprintf("s%02d", i)
m.AddServer(&metadata.Server{Name: nodeName})