From 8107a3bb10a514788940a98e9c31dabac6775542 Mon Sep 17 00:00:00 2001 From: Stepan Cenek Date: Wed, 24 Jul 2024 13:24:04 +0200 Subject: [PATCH] [NGNSDS-632] TLS support Signed-off-by: Stepan Cenek Signed-off-by: Tony Chen --- ais/backend/ais.go | 12 ++-- ais/backend/ht.go | 2 +- ais/daemon.go | 8 +++ ais/htcommon.go | 15 ++++- cmd/authn/mgr.go | 2 +- cmn/client.go | 37 +++++++---- cmn/config.go | 1 + cmn/tls/loader.go | 142 ++++++++++++++++++++++++++++++++++++++++ docs/configuration.md | 2 +- ext/dload/dispatcher.go | 2 +- 10 files changed, 202 insertions(+), 21 deletions(-) create mode 100644 cmn/tls/loader.go diff --git a/ais/backend/ais.go b/ais/backend/ais.go index 2ad46c043b..32e3269c74 100644 --- a/ais/backend/ais.go +++ b/ais/backend/ais.go @@ -198,7 +198,7 @@ func (m *AISbp) GetInfoInternal() (res meta.RemAisVec) { func (m *AISbp) GetInfo(clusterConf cmn.BackendConfAIS) (res meta.RemAisVec) { var ( cfg = cmn.GCO.Get() - cliPlain, cliTLS = remaisClients(&cfg.Client) + cliPlain, cliTLS = remaisClients(&cfg.ClusterConfig) ) m.mu.RLock() res.A = make([]*meta.RemAis, 0, len(m.remote)) @@ -245,8 +245,12 @@ func (m *AISbp) GetInfo(clusterConf cmn.BackendConfAIS) (res meta.RemAisVec) { return } -func remaisClients(clientConf *cmn.ClientConf) (client, clientTLS *http.Client) { - return cmn.NewDefaultClients(clientConf.Timeout.D()) +func remaisClients(cfg *cmn.ClusterConfig) (client, clientTLS *http.Client) { + sargs := cfg.Net.HTTP.ToTLS() + if cfg.Net.HTTP.UseHTTPS { + return cmn.NewDefaultClients(cfg.Client.Timeout.D(), &sargs) + } + return cmn.NewDefaultClients(cfg.Client.Timeout.D(), nil) } // A list of remote AIS URLs can contains both HTTP and HTTPS links at the @@ -257,7 +261,7 @@ func (r *remAis) init(alias string, confURLs []string, cfg *cmn.ClusterConfig) ( var ( url string remSmap, smap *meta.Smap - cliH, cliTLS = remaisClients(&cfg.Client) + cliH, cliTLS = remaisClients(cfg) ) for _, u := range confURLs { client := cliH diff --git a/ais/backend/ht.go b/ais/backend/ht.go index 7af43fcb26..89d5832496 100644 --- a/ais/backend/ht.go +++ b/ais/backend/ht.go @@ -39,7 +39,7 @@ func NewHT(t core.TargetPut, config *cmn.Config, tstats stats.Tracker) (core.Bac t: t, base: base{provider: apc.HT}, } - bp.cliH, bp.cliTLS = cmn.NewDefaultClients(config.Client.TimeoutLong.D()) + bp.cliH, bp.cliTLS = cmn.NewDefaultClients(config.Client.TimeoutLong.D(), nil) bp.init(t.Snode(), tstats) return bp, nil } diff --git a/ais/daemon.go b/ais/daemon.go index 7918b1b670..373af8a101 100644 --- a/ais/daemon.go +++ b/ais/daemon.go @@ -21,6 +21,7 @@ import ( "github.com/NVIDIA/aistore/cmn/debug" "github.com/NVIDIA/aistore/cmn/k8s" "github.com/NVIDIA/aistore/cmn/nlog" + "github.com/NVIDIA/aistore/cmn/tls" "github.com/NVIDIA/aistore/core/meta" "github.com/NVIDIA/aistore/fs" "github.com/NVIDIA/aistore/hk" @@ -201,6 +202,13 @@ func initDaemon(version, buildTime string) cos.Runner { // declared xactions, as per xact/api.go xreg.Init() + if config.Net.HTTP.UseHTTPS { + err = tls.Init(config.Net.HTTP.Certificate, config.Net.HTTP.CertKey) + if err != nil { + cos.ExitLogf("failed to initialize Certificate Manager: %v", err) + } + } + // primary 'host[:port]' endpoint or URL from the environment if daemon.EP = os.Getenv(env.AIS.PrimaryEP); daemon.EP != "" { scheme := "http" diff --git a/ais/htcommon.go b/ais/htcommon.go index 841368ef1a..e58d04414f 100644 --- a/ais/htcommon.go +++ b/ais/htcommon.go @@ -28,6 +28,7 @@ import ( "github.com/NVIDIA/aistore/cmn/cos" "github.com/NVIDIA/aistore/cmn/debug" "github.com/NVIDIA/aistore/cmn/nlog" + aistls "github.com/NVIDIA/aistore/cmn/tls" "github.com/NVIDIA/aistore/core/meta" "github.com/NVIDIA/aistore/ext/etl" "github.com/NVIDIA/aistore/memsys" @@ -558,7 +559,8 @@ func (server *netServer) listen(addr string, logger *log.Logger, tlsConf *tls.Co retry: if config.Net.HTTP.UseHTTPS { tag = "HTTPS" - err = server.s.ListenAndServeTLS(config.Net.HTTP.Certificate, config.Net.HTTP.CertKey) + // Listen and Serve TLS using certificates provided using the GetCertificate() instead of static files. + err = server.s.ListenAndServeTLS("", "") } else { err = server.s.ListenAndServe() } @@ -581,6 +583,9 @@ func newTLS(conf *cmn.HTTPConf) (tlsConf *tls.Config, err error) { caCert []byte clientAuth = tls.ClientAuthType(conf.ClientAuthTLS) ) + tlsConf = &tls.Config{ + ClientAuth: clientAuth, + } if clientAuth > tls.RequestClientCert { if caCert, err = os.ReadFile(conf.ClientCA); err != nil { return @@ -589,8 +594,14 @@ func newTLS(conf *cmn.HTTPConf) (tlsConf *tls.Config, err error) { if ok := pool.AppendCertsFromPEM(caCert); !ok { return nil, fmt.Errorf("tls: failed to append CA certs from PEM: %q", conf.ClientCA) } + tlsConf.ClientCAs = pool + } + if conf.Certificate != "" && conf.CertKey != "" { + if !aistls.IsLoaderSet() { + return nil, errors.New("tls: certificate manager not set") + } + tlsConf.GetCertificate = aistls.GetCert() } - tlsConf = &tls.Config{ClientAuth: clientAuth, ClientCAs: pool} return } diff --git a/cmd/authn/mgr.go b/cmd/authn/mgr.go index 139b740dc4..42bdded4ed 100644 --- a/cmd/authn/mgr.go +++ b/cmd/authn/mgr.go @@ -50,7 +50,7 @@ func newMgr(driver kvdb.Driver) (m *mgr, err error) { m = &mgr{ db: driver, } - m.clientH, m.clientTLS = cmn.NewDefaultClients(time.Duration(Conf.Timeout.Default)) + m.clientH, m.clientTLS = cmn.NewDefaultClients(time.Duration(Conf.Timeout.Default), nil) err = initializeDB(driver) return } diff --git a/cmn/client.go b/cmn/client.go index 0cc99fe977..fdcc6b7c97 100644 --- a/cmn/client.go +++ b/cmn/client.go @@ -16,6 +16,7 @@ import ( "github.com/NVIDIA/aistore/api/env" "github.com/NVIDIA/aistore/cmn/cos" + aistls "github.com/NVIDIA/aistore/cmn/tls" ) type ( @@ -105,24 +106,36 @@ func NewTLS(sargs TLSArgs) (tlsConf *tls.Config, _ error) { } } tlsConf = &tls.Config{RootCAs: pool, InsecureSkipVerify: sargs.SkipVerify} - if sargs.Certificate != "" { - cert, err := tls.LoadX509KeyPair(sargs.Certificate, sargs.Key) - if err != nil { - var hint string - if os.IsNotExist(err) { - hint = "\n(hint: check the two filenames for existence/accessibility)" + if sargs.Certificate != "" && sargs.Key != "" { + if aistls.IsLoaderSet() { + // Certificate Manager initiated as part of service startup + tlsConf.GetClientCertificate = aistls.GetClientCert() + } else { + // One-shot client probably + cert, err := tls.LoadX509KeyPair(sargs.Certificate, sargs.Key) + if err != nil { + var hint string + if os.IsNotExist(err) { + hint = "\n(hint: check the two filenames for existence/accessibility)" + } + return nil, fmt.Errorf("client tls: failed to load public/private key pair: (%q, %q)%s", + sargs.Certificate, sargs.Key, hint) } - return nil, fmt.Errorf("client tls: failed to load public/private key pair: (%q, %q)%s", - sargs.Certificate, sargs.Key, hint) + tlsConf.Certificates = []tls.Certificate{cert} } - tlsConf.Certificates = []tls.Certificate{cert} } return tlsConf, nil } -func NewDefaultClients(timeout time.Duration) (clientH, clientTLS *http.Client) { +// NewDefaultClients creates and returns a pair of HTTP clients: one without TLS and one with TLS. +// If the provided TLSArgs (sargs) is nil, the clientTLS will be a standard HTTP client without TLS. +func NewDefaultClients(timeout time.Duration, sargs *TLSArgs) (clientH, clientTLS *http.Client) { clientH = NewClient(TransportArgs{Timeout: timeout}) - clientTLS = NewClientTLS(TransportArgs{Timeout: timeout}, TLSArgs{SkipVerify: true}) + if sargs != nil { + clientTLS = NewClientTLS(TransportArgs{Timeout: timeout}, *sargs) + } else { + clientTLS = NewClient(TransportArgs{Timeout: timeout}) + } return } @@ -158,6 +171,8 @@ func EnvToTLS(sargs *TLSArgs) { sargs.Key = s } if s := os.Getenv(env.AIS.ClientCA); s != "" { + // XXX This should be RootCA for clients + // https://pkg.go.dev/crypto/tls sargs.ClientCA = s } if s := os.Getenv(env.AIS.SkipVerifyCrt); s != "" { diff --git a/cmn/config.go b/cmn/config.go index f264722551..8d7a124a0a 100644 --- a/cmn/config.go +++ b/cmn/config.go @@ -455,6 +455,7 @@ type ( UseHTTPS bool `json:"use_https"` // use HTTPS SkipVerifyCrt bool `json:"skip_verify"` // skip X509 cert verification (used with self-signed certs) Chunked bool `json:"chunked_transfer"` // (https://tools.ietf.org/html/rfc7230#page-36; not used since 02/23) + Port int `json:"port"` // AuthN port } HTTPConfToSet struct { Certificate *string `json:"server_crt,omitempty"` diff --git a/cmn/tls/loader.go b/cmn/tls/loader.go new file mode 100644 index 0000000000..ff96b3e4e7 --- /dev/null +++ b/cmn/tls/loader.go @@ -0,0 +1,142 @@ +// Package tls provides support for TLS. +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + */ +package tls + +import ( + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "errors" + "sync/atomic" + "time" + + "github.com/NVIDIA/aistore/cmn/debug" + "github.com/NVIDIA/aistore/cmn/nlog" + "github.com/NVIDIA/aistore/hk" +) + +type certLoader struct { + cert atomic.Pointer[tls.Certificate] + certFile, keyFile string + retries int +} +type GetCertCB func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) +type GetClientCertCB func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) + +var ( + loader *certLoader +) + +const loadInterval = 1 * time.Hour +const loadRetries = 6 + +func Init(certFile, keyFile string) (err error) { + debug.Assertf(loader != nil, "TLS loader shouldn't be initialized more than once") + + // Allow creation of nil loader, that makes some check easier later. + if certFile != "" && keyFile != "" { + loader, err = startLoader(certFile, keyFile) + } + if err != nil { + nlog.Warningln("Fail to load TLS certificates at start up") + loader = nil + } + return err +} + +func GetCert() GetCertCB { + if loader == nil { + return nil + } + return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { return loader.getCert(), nil } +} + +func GetClientCert() GetClientCertCB { + if loader == nil { + return nil + } + return func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { return loader.getCert(), nil } +} + +func IsLoaderSet() bool { + return loader != nil +} + +// startLoader will monitor files of certPath and keyPath and reload certificates +// if any of the two was updated. +func startLoader(certPath, keyPath string) (c *certLoader, err error) { + c = &certLoader{ + certFile: certPath, + keyFile: keyPath, + retries: 0, + } + // Immediately try to load existing certs. + if err := c.load(); err != nil { + return nil, err + } + + hk.Reg("tlsloader", c.housekeep, loadInterval) + return +} + +func (c *certLoader) housekeep() time.Duration { + if err := c.load(); err != nil { + c.retries++ + if c.retries > loadRetries { + nlog.Errorf("unable to load TLS certificate: %v", err) + debug.AssertNoErr(err) + } + } else { + c.retries = 0 + } + return loadInterval +} + +func (c *certLoader) getCert() *tls.Certificate { + return c.cert.Load() +} + +func (c *certLoader) load() error { + cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile) + if err != nil { + nlog.Errorln("failed to load X509 key pair:", err) + return err + } + + // Compare fingerprints of previous and current certificates, and log if updated + if prevCert := c.cert.Load(); prevCert != nil { + var ( + prevFringerprint string + curFingerprint string + ) + if prevFringerprint, err = fingerprint(prevCert); err != nil { + nlog.Errorln(err) + } + if curFingerprint, err = fingerprint(&cert); err != nil { + nlog.Errorln(err) + } + if prevFringerprint != curFingerprint { + nlog.Infof("Certificate has changed. New fingerprint: %s", curFingerprint) + } + } + c.cert.Store(&cert) + return nil +} + +// Function to compute the fingerprint of a TLS certificate +func fingerprint(tlsCert *tls.Certificate) (string, error) { + if tlsCert.Leaf == nil { + // Parse the leaf certificate if not already parsed + var err error + tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return "", errors.New("error on parsing TLS certificate") + } + } + + hash := sha256.Sum256(tlsCert.Leaf.Raw) + return hex.EncodeToString(hash[:]), nil +} diff --git a/docs/configuration.md b/docs/configuration.md index 87b0934b99..e0b6898e84 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -29,7 +29,7 @@ Majority of the configuration knobs can be changed at runtime (and at any time). For the most part, commands to view and update (CLI, cluster, node) configuration can be found [here](/docs/cli/config.md). -The [same document](docs/cli/config.md) also contains a brief theory of operation, command descriptions, numerous usage examples and more. +The [same document](/docs/cli/config.md) also contains a brief theory of operation, command descriptions, numerous usage examples and more. > **Important:** as an input, CLI accepts both plain text and JSON-formatted values. For the latter, make sure to embed the (JSON value) argument into single quotes, e.g.: diff --git a/ext/dload/dispatcher.go b/ext/dload/dispatcher.go index 659def6cfc..54b38075eb 100644 --- a/ext/dload/dispatcher.go +++ b/ext/dload/dispatcher.go @@ -60,7 +60,7 @@ type ( var g global func Init(tstats stats.Tracker, db kvdb.Driver, clientConf *cmn.ClientConf) { - g.clientH, g.clientTLS = cmn.NewDefaultClients(clientConf.TimeoutLong.D()) + g.clientH, g.clientTLS = cmn.NewDefaultClients(clientConf.TimeoutLong.D(), nil) if db == nil { // unit tests only return