Skip to content

Commit

Permalink
add client cert auth for upstream connections
Browse files Browse the repository at this point in the history
  • Loading branch information
stlaz committed Nov 30, 2022
1 parent f9b65c7 commit c4e7cd1
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 11 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Usage of _output/kube-rbac-proxy:
--tls-reload-interval duration The interval at which to watch for TLS certificate changes, by default set to 1 minute. (default 1m0s)
--upstream string The upstream URL to proxy to once requests have successfully been authenticated and authorized.
--upstream-ca-file string The CA the upstream uses for TLS connection. This is required when the upstream uses TLS and its own CA certificate
--upstream-client-cert-file string If set, the client will be used to authenticate the proxy to upstream. Requires --upstream-client-key-file to be set, too.
--upstream-client-key-file string The key matching the certificate from --upstream-client-cert-file. If set, requires --upstream-client-cert-file to be set, too.
--upstream-force-h2c Force h2c to communiate with the upstream. This is required when the upstream speaks h2c(http/2 cleartext - insecure variant of http/2) only. For example, go-grpc server in the insecure mode, such as helm's tiller w/o TLS, speaks h2c only
-v, --v Level number for the log level verbosity
--vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging
Expand Down
7 changes: 6 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ type tlsConfig struct {
minVersion string
cipherSuites []string
reloadInterval time.Duration

upstreamClientCertFile string
upstreamClientKeyFile string
}

type configfile struct {
Expand Down Expand Up @@ -114,6 +117,8 @@ func main() {
flagset.StringVar(&cfg.tls.minVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version supported. Value must match version names from https://golang.org/pkg/crypto/tls/#pkg-constants.")
flagset.StringSliceVar(&cfg.tls.cipherSuites, "tls-cipher-suites", nil, "Comma-separated list of cipher suites for the server. Values are from tls package constants (https://golang.org/pkg/crypto/tls/#pkg-constants). If omitted, the default Go cipher suites will be used")
flagset.DurationVar(&cfg.tls.reloadInterval, "tls-reload-interval", time.Minute, "The interval at which to watch for TLS certificate changes, by default set to 1 minute.")
flagset.StringVar(&cfg.tls.upstreamClientCertFile, "upstream-client-cert-file", "", "If set, the client will be used to authenticate the proxy to upstream. Requires --upstream-client-key-file to be set, too.")
flagset.StringVar(&cfg.tls.upstreamClientKeyFile, "upstream-client-key-file", "", "The key matching the certificate from --upstream-client-cert-file. If set, requires --upstream-client-cert-file to be set, too.")

// Auth flags
flagset.StringVar(&cfg.auth.Authentication.X509.ClientCAFile, "client-ca-file", "", "If set, any request presenting a client certificate signed by one of the authorities in the client-ca-file is authenticated with an identity corresponding to the CommonName of the client certificate.")
Expand Down Expand Up @@ -245,7 +250,7 @@ For more information, please go to https://github.com/brancz/kube-rbac-proxy/iss
sarAuthorizer,
)

upstreamTransport, err := initTransport(cfg.upstreamCAFile)
upstreamTransport, err := initTransport(cfg.upstreamCAFile, cfg.tls.upstreamClientCertFile, cfg.tls.upstreamClientKeyFile)
if err != nil {
klog.Fatalf("Failed to set up upstream TLS connection: %v", err)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/authn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ type AuthnConfig struct {

// X509Config holds public client certificate used for authentication requests if specified
type X509Config struct {
ClientCAFile string
ClientCAFile string
UpstreamClientCertificate string
UpstreamClientKey string
}

// TokenConfig holds configuration as to how token authentication is to be done
Expand Down
26 changes: 20 additions & 6 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,26 @@ import (
"time"
)

func initTransport(upstreamCAFile string) (http.RoundTripper, error) {
func initTransport(upstreamCAFile, upstreamClientCertPath, upstreamClientKeyPath string) (http.RoundTripper, error) {
if upstreamCAFile == "" {
return http.DefaultTransport, nil
}

rootPEM, err := os.ReadFile(upstreamCAFile)
upstreamCAPEM, err := os.ReadFile(upstreamCAFile)
if err != nil {
return nil, fmt.Errorf("error reading upstream CA file: %v", err)
return nil, fmt.Errorf("error reading upstream CA file: %w", err)
}

roots := x509.NewCertPool()
if ok := roots.AppendCertsFromPEM([]byte(rootPEM)); !ok {
var certKeyPair tls.Certificate
if len(upstreamClientCertPath) > 0 {
certKeyPair, err = tls.LoadX509KeyPair(upstreamClientCertPath, upstreamClientKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read upstream client cert/key: %w", err)
}
}

upstreamCAPool := x509.NewCertPool()
if ok := upstreamCAPool.AppendCertsFromPEM([]byte(upstreamCAPEM)); !ok {
return nil, errors.New("error parsing upstream CA certificate")
}

Expand All @@ -54,7 +62,13 @@ func initTransport(upstreamCAFile string) (http.RoundTripper, error) {
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: roots},
TLSClientConfig: &tls.Config{
RootCAs: upstreamCAPool,
},
}

if certKeyPair.Certificate != nil {
transport.TLSClientConfig.Certificates = []tls.Certificate{certKeyPair}
}

return transport, nil
Expand Down
160 changes: 157 additions & 3 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -16,12 +16,28 @@ limitations under the License.
package main

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httputil"
"path/filepath"
"testing"
"time"

certutil "k8s.io/client-go/util/cert"
"k8s.io/client-go/util/keyutil"
)

func TestInitTransportWithDefault(t *testing.T) {
roundTripper, err := initTransport("")
roundTripper, err := initTransport("", "", "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -32,7 +48,7 @@ func TestInitTransportWithDefault(t *testing.T) {
}

func TestInitTransportWithCustomCA(t *testing.T) {
roundTripper, err := initTransport("test/ca.pem")
roundTripper, err := initTransport("test/ca.pem", "", "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -42,3 +58,141 @@ func TestInitTransportWithCustomCA(t *testing.T) {
t.Error("expected root CA to be set, got nil")
}
}

func testHTTPHandler(w http.ResponseWriter, req *http.Request) {
if len(req.TLS.PeerCertificates) > 0 {
w.Write([]byte("ok"))
return
} else {
reqDump, _ := httputil.DumpRequest(req, false)
resp := fmt.Sprintf("got request without client certificates:\n%s\n", reqDump)
resp += fmt.Sprintf("TLS config: %#v\n", req.TLS)
http.Error(w, resp, http.StatusBadRequest)
}
}

func TestInitTransportWithClientCertAuth(t *testing.T) {
tlsServer := http.Server{
Handler: http.HandlerFunc(testHTTPHandler),
}

cert, key, err := certutil.GenerateSelfSignedCertKey("127.0.0.1", nil, nil)
if err != nil {
t.Fatalf("failed to create a new serving cert: %v", err)
}

tlsCert, err := tls.X509KeyPair(cert, key)
if err != nil {
t.Fatalf("failed to load a new serving cert: %v", err)
}

clientCert, clientKey, clientCA, err := generateClientCert(t)
if err != nil {
t.Fatalf("failed to generate client cert: %v", err)
}

l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on secure address: %v", err)
}
defer l.Close()
tlsListener := tls.NewListener(l, &tls.Config{
Certificates: []tls.Certificate{tlsCert},
ClientCAs: clientCA,
ClientAuth: tls.RequireAndVerifyClientCert,
})
defer tlsListener.Close()

go func() {
if err := tlsServer.Serve(tlsListener); err != nil {
t.Logf("failed to run the test server: %v", err)
}
}()
defer tlsServer.Close()

tmpDir := t.TempDir()
serverCertPath := filepath.Join(tmpDir, "server.crt")
clientCertPath := filepath.Join(tmpDir, "client.crt")
clientKeyPath := filepath.Join(tmpDir, "client.key")

if err := certutil.WriteCert(serverCertPath, cert); err != nil {
t.Fatalf("failed to write server cert: %v", err)
}
if err := certutil.WriteCert(clientCertPath, clientCert); err != nil {
t.Fatalf("failed to write client cert: %v", err)
}
if err := keyutil.WriteKey(clientKeyPath, clientKey); err != nil {
t.Fatalf("failed to write client key: %v", err)
}

roundTripper, err := initTransport(serverCertPath, clientCertPath, clientKeyPath)
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
}

httpReq, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://127.0.0.1:%d", l.Addr().(*net.TCPAddr).Port), nil)
if err != nil {
t.Fatalf("failed to create an HTTP request: %v", err)
}

resp, err := roundTripper.RoundTrip(httpReq)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Logf("failed to read response body: %v", err)
}
t.Logf("response with failure logs:\n%s", respBody)
t.Errorf("expected the response code to be '%d', but it is '%d'", http.StatusOK, resp.StatusCode)
}
}

func generateClientCert(t *testing.T) ([]byte, []byte, *x509.CertPool, error) {
t.Helper()

privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to generate private key: %v", err)
}
ca, err := certutil.NewSelfSignedCACert(certutil.Config{CommonName: "testing-ca"}, privKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to generate CA cert: %v", err)
}

privKeyClient, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to generate private key: %v", err)
}

certDER, err := x509.CreateCertificate(rand.Reader,
&x509.Certificate{
Subject: pkix.Name{CommonName: "testing-client"},
SerialNumber: big.NewInt(15233),
NotBefore: time.Now().Add(-5 * time.Second),
NotAfter: time.Now().Add(1 * time.Minute),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
},
ca, privKeyClient.Public(), privKey,
)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create a client cert: %v", err)
}

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

caPool := x509.NewCertPool()
caPool.AddCert(ca)

privKeyPEM, err := keyutil.MarshalPrivateKeyToPEM(privKeyClient)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to encode private key to pem: %v", err)
}

return certPEM, privKeyPEM, caPool, nil
}

0 comments on commit c4e7cd1

Please sign in to comment.