From 175f8d5a59d4d58092ba2dce92fa5c5272389e4d Mon Sep 17 00:00:00 2001 From: Josef Johansson Date: Fri, 21 May 2021 15:54:42 +0200 Subject: [PATCH] feat: plugins/common/tls/config.go: Filter client certificates by DNS names Signed-off-by: Josef Johansson --- docs/TLS.md | 6 +++ plugins/common/tls/config.go | 40 ++++++++++++++++---- plugins/common/tls/config_test.go | 63 ++++++++++++++++++++++++++----- 3 files changed, 92 insertions(+), 17 deletions(-) diff --git a/docs/TLS.md b/docs/TLS.md index 355da32bb98be..74b2512f1e59d 100644 --- a/docs/TLS.md +++ b/docs/TLS.md @@ -31,6 +31,12 @@ The server TLS configuration provides support for TLS mutual authentication: ## enable mutually authenticated TLS connections. # tls_allowed_cacerts = ["/etc/telegraf/clientca.pem"] +## Set one or more allowed DNS name to enable a whitelist +## to verify incoming client certificates. +## It will go through all available SAN in the certificate, +## if of them matches the request is accepted. +# tls_allowed_dns_names = ["client.example.org"] + ## Add service certificate and key. # tls_cert = "/etc/telegraf/cert.pem" # tls_key = "/etc/telegraf/key.pem" diff --git a/plugins/common/tls/config.go b/plugins/common/tls/config.go index 586ec8fd4a417..271d63e7cac2e 100644 --- a/plugins/common/tls/config.go +++ b/plugins/common/tls/config.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "github.com/influxdata/telegraf/internal/choice" "os" "strings" ) @@ -24,12 +25,13 @@ type ClientConfig struct { // ServerConfig represents the standard server TLS config. type ServerConfig struct { - TLSCert string `toml:"tls_cert"` - TLSKey string `toml:"tls_key"` - TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"` - TLSCipherSuites []string `toml:"tls_cipher_suites"` - TLSMinVersion string `toml:"tls_min_version"` - TLSMaxVersion string `toml:"tls_max_version"` + TLSCert string `toml:"tls_cert"` + TLSKey string `toml:"tls_key"` + TLSAllowedCACerts []string `toml:"tls_allowed_cacerts"` + TLSCipherSuites []string `toml:"tls_cipher_suites"` + TLSMinVersion string `toml:"tls_min_version"` + TLSMaxVersion string `toml:"tls_max_version"` + TLSAllowedDNSNames []string `toml:"tls_allowed_dns_names"` } // TLSConfig returns a tls.Config, may be nil without error if TLS is not @@ -141,6 +143,12 @@ func (c *ServerConfig) TLSConfig() (*tls.Config, error) { "tls min version %q can't be greater than tls max version %q", tlsConfig.MinVersion, tlsConfig.MaxVersion) } + // Since clientAuth is tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + // there must be certs to validate. + if len(c.TLSAllowedCACerts) > 0 && len(c.TLSAllowedDNSNames) > 0 { + tlsConfig.VerifyPeerCertificate = c.verifyPeerCertificate + } + return tlsConfig, nil } @@ -152,8 +160,7 @@ func makeCertPool(certFiles []string) (*x509.CertPool, error) { return nil, fmt.Errorf( "could not read certificate %q: %v", certFile, err) } - ok := pool.AppendCertsFromPEM(pem) - if !ok { + if !pool.AppendCertsFromPEM(pem) { return nil, fmt.Errorf( "could not parse any PEM certificates %q: %v", certFile, err) } @@ -172,3 +179,20 @@ func loadCertificate(config *tls.Config, certFile, keyFile string) error { config.BuildNameToCertificate() return nil } + +func (c *ServerConfig) verifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + // The certificate chain is client + intermediate + root. + // Let's review the client certificate. + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return fmt.Errorf("could not validate peer certificate: %v", err) + } + + for _, name := range cert.DNSNames { + if choice.Contains(name, c.TLSAllowedDNSNames) { + return nil + } + } + + return fmt.Errorf("peer certificate not in allowed DNS Name list: %v", cert.DNSNames) +} diff --git a/plugins/common/tls/config_test.go b/plugins/common/tls/config_test.go index 2784ace6920e3..b118c48b5f912 100644 --- a/plugins/common/tls/config_test.go +++ b/plugins/common/tls/config_test.go @@ -128,12 +128,13 @@ func TestServerConfig(t *testing.T) { { name: "success", server: tls.ServerConfig{ - TLSCert: pki.ServerCertPath(), - TLSKey: pki.ServerKeyPath(), - TLSAllowedCACerts: []string{pki.CACertPath()}, - TLSCipherSuites: []string{pki.CipherSuite()}, - TLSMinVersion: pki.TLSMinVersion(), - TLSMaxVersion: pki.TLSMaxVersion(), + TLSCert: pki.ServerCertPath(), + TLSKey: pki.ServerKeyPath(), + TLSAllowedCACerts: []string{pki.CACertPath()}, + TLSCipherSuites: []string{pki.CipherSuite()}, + TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"}, + TLSMinVersion: pki.TLSMinVersion(), + TLSMaxVersion: pki.TLSMaxVersion(), }, }, { @@ -293,9 +294,10 @@ func TestConnect(t *testing.T) { } serverConfig := tls.ServerConfig{ - TLSCert: pki.ServerCertPath(), - TLSKey: pki.ServerKeyPath(), - TLSAllowedCACerts: []string{pki.CACertPath()}, + TLSCert: pki.ServerCertPath(), + TLSKey: pki.ServerKeyPath(), + TLSAllowedCACerts: []string{pki.CACertPath()}, + TLSAllowedDNSNames: []string{"localhost", "127.0.0.1"}, } serverTLSConfig, err := serverConfig.TLSConfig() @@ -323,3 +325,46 @@ func TestConnect(t *testing.T) { require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) } + +func TestConnectWrongDNS(t *testing.T) { + clientConfig := tls.ClientConfig{ + TLSCA: pki.CACertPath(), + TLSCert: pki.ClientCertPath(), + TLSKey: pki.ClientKeyPath(), + } + + serverConfig := tls.ServerConfig{ + TLSCert: pki.ServerCertPath(), + TLSKey: pki.ServerKeyPath(), + TLSAllowedCACerts: []string{pki.CACertPath()}, + TLSAllowedDNSNames: []string{"localhos", "127.0.0.2"}, + } + + serverTLSConfig, err := serverConfig.TLSConfig() + require.NoError(t, err) + + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + ts.TLS = serverTLSConfig + + ts.StartTLS() + defer ts.Close() + + clientTLSConfig, err := clientConfig.TLSConfig() + require.NoError(t, err) + + client := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: clientTLSConfig, + }, + Timeout: 10 * time.Second, + } + + resp, err := client.Get(ts.URL) + require.Error(t, err) + if resp != nil { + err = resp.Body.Close() + require.NoError(t, err) + } +}