Skip to content

Commit

Permalink
[NGNSDS-632] TLS support
Browse files Browse the repository at this point in the history
Signed-off-by: Stepan Cenek <[email protected]>
Signed-off-by: Tony Chen <[email protected]>
  • Loading branch information
Stepan Cenek authored and Nahemah1022 committed Aug 23, 2024
1 parent 1a8503f commit 8107a3b
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 21 deletions.
12 changes: 8 additions & 4 deletions ais/backend/ais.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (m *AISbp) GetInfoInternal() (res meta.RemAisVec) {
func (m *AISbp) GetInfo(clusterConf cmn.BackendConfAIS) (res meta.RemAisVec) {
var (
cfg = cmn.GCO.Get()
cliPlain, cliTLS = remaisClients(&cfg.Client)
cliPlain, cliTLS = remaisClients(&cfg.ClusterConfig)
)
m.mu.RLock()
res.A = make([]*meta.RemAis, 0, len(m.remote))
Expand Down Expand Up @@ -245,8 +245,12 @@ func (m *AISbp) GetInfo(clusterConf cmn.BackendConfAIS) (res meta.RemAisVec) {
return
}

func remaisClients(clientConf *cmn.ClientConf) (client, clientTLS *http.Client) {
return cmn.NewDefaultClients(clientConf.Timeout.D())
func remaisClients(cfg *cmn.ClusterConfig) (client, clientTLS *http.Client) {
sargs := cfg.Net.HTTP.ToTLS()
if cfg.Net.HTTP.UseHTTPS {
return cmn.NewDefaultClients(cfg.Client.Timeout.D(), &sargs)
}
return cmn.NewDefaultClients(cfg.Client.Timeout.D(), nil)
}

// A list of remote AIS URLs can contains both HTTP and HTTPS links at the
Expand All @@ -257,7 +261,7 @@ func (r *remAis) init(alias string, confURLs []string, cfg *cmn.ClusterConfig) (
var (
url string
remSmap, smap *meta.Smap
cliH, cliTLS = remaisClients(&cfg.Client)
cliH, cliTLS = remaisClients(cfg)
)
for _, u := range confURLs {
client := cliH
Expand Down
2 changes: 1 addition & 1 deletion ais/backend/ht.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func NewHT(t core.TargetPut, config *cmn.Config, tstats stats.Tracker) (core.Bac
t: t,
base: base{provider: apc.HT},
}
bp.cliH, bp.cliTLS = cmn.NewDefaultClients(config.Client.TimeoutLong.D())
bp.cliH, bp.cliTLS = cmn.NewDefaultClients(config.Client.TimeoutLong.D(), nil)
bp.init(t.Snode(), tstats)
return bp, nil
}
Expand Down
8 changes: 8 additions & 0 deletions ais/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/NVIDIA/aistore/cmn/debug"
"github.com/NVIDIA/aistore/cmn/k8s"
"github.com/NVIDIA/aistore/cmn/nlog"
"github.com/NVIDIA/aistore/cmn/tls"
"github.com/NVIDIA/aistore/core/meta"
"github.com/NVIDIA/aistore/fs"
"github.com/NVIDIA/aistore/hk"
Expand Down Expand Up @@ -201,6 +202,13 @@ func initDaemon(version, buildTime string) cos.Runner {
// declared xactions, as per xact/api.go
xreg.Init()

if config.Net.HTTP.UseHTTPS {
err = tls.Init(config.Net.HTTP.Certificate, config.Net.HTTP.CertKey)
if err != nil {
cos.ExitLogf("failed to initialize Certificate Manager: %v", err)
}
}

// primary 'host[:port]' endpoint or URL from the environment
if daemon.EP = os.Getenv(env.AIS.PrimaryEP); daemon.EP != "" {
scheme := "http"
Expand Down
15 changes: 13 additions & 2 deletions ais/htcommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/NVIDIA/aistore/cmn/cos"
"github.com/NVIDIA/aistore/cmn/debug"
"github.com/NVIDIA/aistore/cmn/nlog"
aistls "github.com/NVIDIA/aistore/cmn/tls"
"github.com/NVIDIA/aistore/core/meta"
"github.com/NVIDIA/aistore/ext/etl"
"github.com/NVIDIA/aistore/memsys"
Expand Down Expand Up @@ -558,7 +559,8 @@ func (server *netServer) listen(addr string, logger *log.Logger, tlsConf *tls.Co
retry:
if config.Net.HTTP.UseHTTPS {
tag = "HTTPS"
err = server.s.ListenAndServeTLS(config.Net.HTTP.Certificate, config.Net.HTTP.CertKey)
// Listen and Serve TLS using certificates provided using the GetCertificate() instead of static files.
err = server.s.ListenAndServeTLS("", "")
} else {
err = server.s.ListenAndServe()
}
Expand All @@ -581,6 +583,9 @@ func newTLS(conf *cmn.HTTPConf) (tlsConf *tls.Config, err error) {
caCert []byte
clientAuth = tls.ClientAuthType(conf.ClientAuthTLS)
)
tlsConf = &tls.Config{
ClientAuth: clientAuth,
}
if clientAuth > tls.RequestClientCert {
if caCert, err = os.ReadFile(conf.ClientCA); err != nil {
return
Expand All @@ -589,8 +594,14 @@ func newTLS(conf *cmn.HTTPConf) (tlsConf *tls.Config, err error) {
if ok := pool.AppendCertsFromPEM(caCert); !ok {
return nil, fmt.Errorf("tls: failed to append CA certs from PEM: %q", conf.ClientCA)
}
tlsConf.ClientCAs = pool
}
if conf.Certificate != "" && conf.CertKey != "" {
if !aistls.IsLoaderSet() {
return nil, errors.New("tls: certificate manager not set")
}
tlsConf.GetCertificate = aistls.GetCert()
}
tlsConf = &tls.Config{ClientAuth: clientAuth, ClientCAs: pool}
return
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/authn/mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func newMgr(driver kvdb.Driver) (m *mgr, err error) {
m = &mgr{
db: driver,
}
m.clientH, m.clientTLS = cmn.NewDefaultClients(time.Duration(Conf.Timeout.Default))
m.clientH, m.clientTLS = cmn.NewDefaultClients(time.Duration(Conf.Timeout.Default), nil)
err = initializeDB(driver)
return
}
Expand Down
37 changes: 26 additions & 11 deletions cmn/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/NVIDIA/aistore/api/env"
"github.com/NVIDIA/aistore/cmn/cos"
aistls "github.com/NVIDIA/aistore/cmn/tls"
)

type (
Expand Down Expand Up @@ -105,24 +106,36 @@ func NewTLS(sargs TLSArgs) (tlsConf *tls.Config, _ error) {
}
}
tlsConf = &tls.Config{RootCAs: pool, InsecureSkipVerify: sargs.SkipVerify}
if sargs.Certificate != "" {
cert, err := tls.LoadX509KeyPair(sargs.Certificate, sargs.Key)
if err != nil {
var hint string
if os.IsNotExist(err) {
hint = "\n(hint: check the two filenames for existence/accessibility)"
if sargs.Certificate != "" && sargs.Key != "" {
if aistls.IsLoaderSet() {
// Certificate Manager initiated as part of service startup
tlsConf.GetClientCertificate = aistls.GetClientCert()
} else {
// One-shot client probably
cert, err := tls.LoadX509KeyPair(sargs.Certificate, sargs.Key)
if err != nil {
var hint string
if os.IsNotExist(err) {
hint = "\n(hint: check the two filenames for existence/accessibility)"
}
return nil, fmt.Errorf("client tls: failed to load public/private key pair: (%q, %q)%s",
sargs.Certificate, sargs.Key, hint)
}
return nil, fmt.Errorf("client tls: failed to load public/private key pair: (%q, %q)%s",
sargs.Certificate, sargs.Key, hint)
tlsConf.Certificates = []tls.Certificate{cert}
}
tlsConf.Certificates = []tls.Certificate{cert}
}
return tlsConf, nil
}

func NewDefaultClients(timeout time.Duration) (clientH, clientTLS *http.Client) {
// NewDefaultClients creates and returns a pair of HTTP clients: one without TLS and one with TLS.
// If the provided TLSArgs (sargs) is nil, the clientTLS will be a standard HTTP client without TLS.
func NewDefaultClients(timeout time.Duration, sargs *TLSArgs) (clientH, clientTLS *http.Client) {
clientH = NewClient(TransportArgs{Timeout: timeout})
clientTLS = NewClientTLS(TransportArgs{Timeout: timeout}, TLSArgs{SkipVerify: true})
if sargs != nil {
clientTLS = NewClientTLS(TransportArgs{Timeout: timeout}, *sargs)
} else {
clientTLS = NewClient(TransportArgs{Timeout: timeout})
}
return
}

Expand Down Expand Up @@ -158,6 +171,8 @@ func EnvToTLS(sargs *TLSArgs) {
sargs.Key = s
}
if s := os.Getenv(env.AIS.ClientCA); s != "" {
// XXX This should be RootCA for clients
// https://pkg.go.dev/crypto/tls
sargs.ClientCA = s
}
if s := os.Getenv(env.AIS.SkipVerifyCrt); s != "" {
Expand Down
1 change: 1 addition & 0 deletions cmn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ type (
UseHTTPS bool `json:"use_https"` // use HTTPS
SkipVerifyCrt bool `json:"skip_verify"` // skip X509 cert verification (used with self-signed certs)
Chunked bool `json:"chunked_transfer"` // (https://tools.ietf.org/html/rfc7230#page-36; not used since 02/23)
Port int `json:"port"` // AuthN port
}
HTTPConfToSet struct {
Certificate *string `json:"server_crt,omitempty"`
Expand Down
142 changes: 142 additions & 0 deletions cmn/tls/loader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Package tls provides support for TLS.
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*/
package tls

import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"sync/atomic"
"time"

"github.com/NVIDIA/aistore/cmn/debug"
"github.com/NVIDIA/aistore/cmn/nlog"
"github.com/NVIDIA/aistore/hk"
)

type certLoader struct {
cert atomic.Pointer[tls.Certificate]
certFile, keyFile string
retries int
}
type GetCertCB func(_ *tls.ClientHelloInfo) (*tls.Certificate, error)
type GetClientCertCB func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error)

var (
loader *certLoader
)

const loadInterval = 1 * time.Hour
const loadRetries = 6

func Init(certFile, keyFile string) (err error) {
debug.Assertf(loader != nil, "TLS loader shouldn't be initialized more than once")

// Allow creation of nil loader, that makes some check easier later.
if certFile != "" && keyFile != "" {
loader, err = startLoader(certFile, keyFile)
}
if err != nil {
nlog.Warningln("Fail to load TLS certificates at start up")
loader = nil
}
return err
}

func GetCert() GetCertCB {
if loader == nil {
return nil
}
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { return loader.getCert(), nil }
}

func GetClientCert() GetClientCertCB {
if loader == nil {
return nil
}
return func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { return loader.getCert(), nil }
}

func IsLoaderSet() bool {
return loader != nil
}

// startLoader will monitor files of certPath and keyPath and reload certificates
// if any of the two was updated.
func startLoader(certPath, keyPath string) (c *certLoader, err error) {
c = &certLoader{
certFile: certPath,
keyFile: keyPath,
retries: 0,
}
// Immediately try to load existing certs.
if err := c.load(); err != nil {
return nil, err
}

hk.Reg("tlsloader", c.housekeep, loadInterval)
return
}

func (c *certLoader) housekeep() time.Duration {
if err := c.load(); err != nil {
c.retries++
if c.retries > loadRetries {
nlog.Errorf("unable to load TLS certificate: %v", err)
debug.AssertNoErr(err)
}
} else {
c.retries = 0
}
return loadInterval
}

func (c *certLoader) getCert() *tls.Certificate {
return c.cert.Load()
}

func (c *certLoader) load() error {
cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile)
if err != nil {
nlog.Errorln("failed to load X509 key pair:", err)
return err
}

// Compare fingerprints of previous and current certificates, and log if updated
if prevCert := c.cert.Load(); prevCert != nil {
var (
prevFringerprint string
curFingerprint string
)
if prevFringerprint, err = fingerprint(prevCert); err != nil {
nlog.Errorln(err)
}
if curFingerprint, err = fingerprint(&cert); err != nil {
nlog.Errorln(err)
}
if prevFringerprint != curFingerprint {
nlog.Infof("Certificate has changed. New fingerprint: %s", curFingerprint)
}
}
c.cert.Store(&cert)
return nil
}

// Function to compute the fingerprint of a TLS certificate
func fingerprint(tlsCert *tls.Certificate) (string, error) {
if tlsCert.Leaf == nil {
// Parse the leaf certificate if not already parsed
var err error
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return "", errors.New("error on parsing TLS certificate")
}
}

hash := sha256.Sum256(tlsCert.Leaf.Raw)
return hex.EncodeToString(hash[:]), nil
}
2 changes: 1 addition & 1 deletion docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Majority of the configuration knobs can be changed at runtime (and at any time).

For the most part, commands to view and update (CLI, cluster, node) configuration can be found [here](/docs/cli/config.md).

The [same document](docs/cli/config.md) also contains a brief theory of operation, command descriptions, numerous usage examples and more.
The [same document](/docs/cli/config.md) also contains a brief theory of operation, command descriptions, numerous usage examples and more.

> **Important:** as an input, CLI accepts both plain text and JSON-formatted values. For the latter, make sure to embed the (JSON value) argument into single quotes, e.g.:
Expand Down
2 changes: 1 addition & 1 deletion ext/dload/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type (
var g global

func Init(tstats stats.Tracker, db kvdb.Driver, clientConf *cmn.ClientConf) {
g.clientH, g.clientTLS = cmn.NewDefaultClients(clientConf.TimeoutLong.D())
g.clientH, g.clientTLS = cmn.NewDefaultClients(clientConf.TimeoutLong.D(), nil)

if db == nil { // unit tests only
return
Expand Down

0 comments on commit 8107a3b

Please sign in to comment.