Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

client: support Root CA rotation on server side #13307

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/pkg/transport/keepalive_listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestNewKeepAliveListener(t *testing.T) {
}
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
tlscfg, err := tlsInfo.ServerConfig()
tlscfg, err := tlsInfo.ReloadableServerConfig()
if err != nil {
t.Fatalf("unexpected serverConfig error: %v", err)
}
Expand Down
251 changes: 178 additions & 73 deletions client/pkg/transport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"

"go.etcd.io/etcd/client/pkg/v3/fileutil"
Expand All @@ -38,6 +40,10 @@ import (
"go.uber.org/zap"
)

const (
defaultRootCAReloadDuration = 5 * time.Minute
)

// NewListener creates a new listner.
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
return newListener(addr, scheme, WithTLSInfo(tlsinfo))
Expand Down Expand Up @@ -185,17 +191,51 @@ type TLSInfo struct {
// EmptyCN indicates that the cert must have empty CN.
// If true, ClientConfig() will return an error for a cert with non empty CN.
EmptyCN bool

tlsConfig atomic.Value // *tls.Config
refreshOnce sync.Once
RefreshDuration time.Duration
EnableRootCAReload bool
refreshDone chan struct{}
}

func (info TLSInfo) String() string {
func (info *TLSInfo) String() string {
return fmt.Sprintf("cert = %s, key = %s, client-cert=%s, client-key=%s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.ClientCertFile, info.ClientKeyFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
}

func (info TLSInfo) Empty() bool {
func (info *TLSInfo) Empty() bool {
return info.CertFile == "" && info.KeyFile == ""
}

func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
func (info *TLSInfo) Clone() *TLSInfo {
return &TLSInfo{
CertFile: info.CertFile,
KeyFile: info.KeyFile,
ClientCertFile: info.ClientCertFile,
ClientKeyFile: info.ClientKeyFile,
TrustedCAFile: info.TrustedCAFile,
ClientCertAuth: info.ClientCertAuth,
CRLFile: info.CRLFile,
InsecureSkipVerify: info.InsecureSkipVerify,
SkipClientSANVerify: info.SkipClientSANVerify,
ServerName: info.ServerName,
HandshakeFailure: info.HandshakeFailure,
CipherSuites: info.CipherSuites,
selfCert: info.selfCert,
parseFunc: info.parseFunc,
AllowedCN: info.AllowedCN,
AllowedHostname: info.AllowedHostname,
Logger: info.Logger,
EmptyCN: info.EmptyCN,
RefreshDuration: info.RefreshDuration,
EnableRootCAReload: info.EnableRootCAReload,
}
}

func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info *TLSInfo, err error) {
if info == nil {
info = &TLSInfo{}
}
info.Logger = lg
if selfSignedCertValidity == 0 {
err = fmt.Errorf("selfSignedCertValidity is invalid,it should be greater than 0")
Expand Down Expand Up @@ -334,6 +374,87 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali
return SelfCert(lg, dirpath, hosts, selfSignedCertValidity)
}

func (info *TLSInfo) startRefresh() {
info.refreshOnce.Do(
func() {
info.loadServerTlsConfig()
if info.EnableRootCAReload {
if info.RefreshDuration == 0 {
info.RefreshDuration = defaultRootCAReloadDuration
}
info.refreshDone = make(chan struct{})
go info.tlsConfigRefreshLoop()
}
},
)
}

func (info *TLSInfo) loadServerTlsConfig() {
if info.Logger != nil {
info.Logger.Info("tls config reload from files")
}
cfg, err := info.serverConfig()
if err == nil {
info.tlsConfig.Store(cfg)
} else {
if info.Logger != nil {
info.Logger.Error("reload tls config error:", zap.Error(err))
}
}
}

func (info *TLSInfo) tlsConfigRefreshLoop() {
ticker := time.NewTicker(info.RefreshDuration)
defer ticker.Stop()
for {
select {
case <-ticker.C:
info.loadServerTlsConfig()
case <-info.refreshDone:
return
}
}
}

func (info *TLSInfo) getClientCertificate() (*tls.Certificate, error) {
certFile, keyFile := info.CertFile, info.KeyFile
if info.ClientCertFile != "" {
certFile, keyFile = info.ClientCertFile, info.ClientKeyFile
}
return info.getCertificates(certFile, keyFile)
}

func (info *TLSInfo) getServerCertificates() (*tls.Certificate, error) {
return info.getCertificates(info.CertFile, info.KeyFile)
}

func (info *TLSInfo) getCertificates(certFile, keyFile string) (*tls.Certificate, error) {
cert, err := tlsutil.NewCert(certFile, keyFile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find cert files",
zap.String("cert-file", certFile),
zap.String("key-file", keyFile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create peer certificate",
zap.String("cert-file", certFile),
zap.String("key-file", keyFile),
zap.Error(err),
)
}
}
if err != nil {
return nil, err
}
return cert, err
}

// baseConfig is called on initial TLS handshake start.
//
// Previously,
Expand All @@ -354,7 +475,7 @@ func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertVali
// handshake, in order to trigger (*tls.Config).GetCertificate and populate
// rest of the certificates on every new TLS connection, even when client
// SNI is empty (e.g. cert only includes IPs).
func (info TLSInfo) baseConfig() (*tls.Config, error) {
func (info *TLSInfo) baseConfig() (*tls.Config, error) {
if info.KeyFile == "" || info.CertFile == "" {
return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
}
Expand Down Expand Up @@ -415,82 +536,15 @@ func (info TLSInfo) baseConfig() (*tls.Config, error) {
return errors.New("client certificate authentication failed")
}
}

// this only reloads certs when there's a client request
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find peer cert files",
zap.String("cert-file", info.CertFile),
zap.String("key-file", info.KeyFile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create peer certificate",
zap.String("cert-file", info.CertFile),
zap.String("key-file", info.KeyFile),
zap.Error(err),
)
}
}
return cert, err
}
cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) {
certfile, keyfile := info.CertFile, info.KeyFile
if info.ClientCertFile != "" {
certfile, keyfile = info.ClientCertFile, info.ClientKeyFile
}
cert, err = tlsutil.NewCert(certfile, keyfile, info.parseFunc)
if os.IsNotExist(err) {
if info.Logger != nil {
info.Logger.Warn(
"failed to find client cert files",
zap.String("cert-file", certfile),
zap.String("key-file", keyfile),
zap.Error(err),
)
}
} else if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"failed to create client certificate",
zap.String("cert-file", certfile),
zap.String("key-file", keyfile),
zap.Error(err),
)
}
}
return cert, err
}
return cfg, nil
}

// cafiles returns a list of CA file paths.
func (info TLSInfo) cafiles() []string {
cs := make([]string, 0)
if info.TrustedCAFile != "" {
cs = append(cs, info.TrustedCAFile)
}
return cs
}

// ServerConfig generates a tls.Config object for use by an HTTP server.
func (info TLSInfo) ServerConfig() (*tls.Config, error) {
func (info *TLSInfo) serverConfig() (*tls.Config, error) {
yishuT marked this conversation as resolved.
Show resolved Hide resolved
cfg, err := info.baseConfig()
if err != nil {
return nil, err
}

if info.Logger == nil {
info.Logger = zap.NewNop()
}

cfg.ClientAuth = tls.NoClientCert
if info.TrustedCAFile != "" || info.ClientCertAuth {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
Expand All @@ -515,11 +569,44 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) {
// setting Max TLS version to TLS 1.2 for go 1.13
cfg.MaxVersion = tls.VersionTLS12

certs, err := info.getServerCertificates()
if err != nil {
return nil, err
}
cfg.Certificates = []tls.Certificate{*certs}
return cfg, nil
}

// cafiles returns a list of CA file paths.
func (info *TLSInfo) cafiles() []string {
cs := make([]string, 0)
if info.TrustedCAFile != "" {
cs = append(cs, info.TrustedCAFile)
}
return cs
}

// ReloadableServerConfig generates a tls.Config object for use by an HTTP server.
func (info *TLSInfo) ReloadableServerConfig() (*tls.Config, error) {
info.startRefresh()
return &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
cfg, ok := info.tlsConfig.Load().(*tls.Config)
if !ok {
return nil, errors.New("server tls configuration not ready")
}
return cfg.Clone(), nil
},
// Needed to tell go http server to serve http2
NextProtos: []string{"h2"},
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return info.getClientCertificate()
},
}, nil
}

// ClientConfig generates a tls.Config object for use by an HTTP client.
func (info TLSInfo) ClientConfig() (*tls.Config, error) {
func (info *TLSInfo) ClientConfig() (*tls.Config, error) {
var cfg *tls.Config
var err error

Expand Down Expand Up @@ -574,9 +661,27 @@ func (info TLSInfo) ClientConfig() (*tls.Config, error) {
// setting Max TLS version to TLS 1.2 for go 1.13
cfg.MaxVersion = tls.VersionTLS12

cert, err := info.getClientCertificate()
if err != nil {
if info.Logger != nil {
info.Logger.Warn(
"cannot create client certificate",
zap.Error(err),
)
}
} else {
cfg.Certificates = []tls.Certificate{*cert}
}

return cfg, nil
}

func (info *TLSInfo) Close() {
yishuT marked this conversation as resolved.
Show resolved Hide resolved
if info.refreshDone != nil {
close(info.refreshDone)
}
}

// IsClosedConnError returns true if the error is from closing listener, cmux.
// copied from golang.org/x/net/http2/http2.go
func IsClosedConnError(err error) bool {
Expand Down
Loading