Skip to content

Commit

Permalink
add client cert auth for upstream connections
Browse files Browse the repository at this point in the history
  • Loading branch information
stlaz committed Nov 29, 2022
1 parent f9b65c7 commit 4ddb68c
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 21 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Usage of _output/kube-rbac-proxy:
--tls-reload-interval duration The interval at which to watch for TLS certificate changes, by default set to 1 minute. (default 1m0s)
--upstream string The upstream URL to proxy to once requests have successfully been authenticated and authorized.
--upstream-ca-file string The CA the upstream uses for TLS connection. This is required when the upstream uses TLS and its own CA certificate
--upstream-client-cert-file string If set, the client will be used to authenticate the proxy to upstream. Requires --upstream-client-key-file to be set, too.
--upstream-client-key-file string The key matching the certificate from --upstream-client-cert-file. If set, requires --upstream-client-cert-file to be set, too.
--upstream-force-h2c Force h2c to communiate with the upstream. This is required when the upstream speaks h2c(http/2 cleartext - insecure variant of http/2) only. For example, go-grpc server in the insecure mode, such as helm's tiller w/o TLS, speaks h2c only
-v, --v Level number for the log level verbosity
--vmodule moduleSpec comma-separated list of pattern=N settings for file-filtered logging
Expand Down
26 changes: 15 additions & 11 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,18 @@ import (
)

type config struct {
insecureListenAddress string
secureListenAddress string
upstream string
upstreamForceH2C bool
upstreamCAFile string
auth proxy.Config
tls tlsConfig
kubeconfigLocation string
allowPaths []string
ignorePaths []string
insecureListenAddress string
secureListenAddress string
upstream string
upstreamForceH2C bool
upstreamCAFile string
upstreamClientCertFile string
upstreamClientKeyFile string
auth proxy.Config
tls tlsConfig
kubeconfigLocation string
allowPaths []string
ignorePaths []string
}

type tlsConfig struct {
Expand Down Expand Up @@ -117,6 +119,8 @@ func main() {

// Auth flags
flagset.StringVar(&cfg.auth.Authentication.X509.ClientCAFile, "client-ca-file", "", "If set, any request presenting a client certificate signed by one of the authorities in the client-ca-file is authenticated with an identity corresponding to the CommonName of the client certificate.")
flagset.StringVar(&cfg.auth.Authentication.X509.UpstreamClientCertificate, "upstream-client-cert-file", "", "If set, the client will be used to authenticate the proxy to upstream. Requires --upstream-client-key-file to be set, too.")
flagset.StringVar(&cfg.auth.Authentication.X509.UpstreamClientKey, "upstream-client-key-file", "", "The key matching the certificate from --upstream-client-cert-file. If set, requires --upstream-client-cert-file to be set, too.")
flagset.BoolVar(&cfg.auth.Authentication.Header.Enabled, "auth-header-fields-enabled", false, "When set to true, kube-rbac-proxy adds auth-related fields to the headers of http requests sent to the upstream")
flagset.StringVar(&cfg.auth.Authentication.Header.UserFieldName, "auth-header-user-field-name", "x-remote-user", "The name of the field inside a http(2) request header to tell the upstream server about the user's name")
flagset.StringVar(&cfg.auth.Authentication.Header.GroupsFieldName, "auth-header-groups-field-name", "x-remote-groups", "The name of the field inside a http(2) request header to tell the upstream server about the user's groups")
Expand Down Expand Up @@ -245,7 +249,7 @@ For more information, please go to https://github.com/brancz/kube-rbac-proxy/iss
sarAuthorizer,
)

upstreamTransport, err := initTransport(cfg.upstreamCAFile)
upstreamTransport, err := initTransport(cfg.upstreamCAFile, cfg.upstreamClientCertFile, cfg.upstreamClientKeyFile)
if err != nil {
klog.Fatalf("Failed to set up upstream TLS connection: %v", err)
}
Expand Down
4 changes: 3 additions & 1 deletion pkg/authn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ type AuthnConfig struct {

// X509Config holds public client certificate used for authentication requests if specified
type X509Config struct {
ClientCAFile string
ClientCAFile string
UpstreamClientCertificate string
UpstreamClientKey string
}

// TokenConfig holds configuration as to how token authentication is to be done
Expand Down
26 changes: 20 additions & 6 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,26 @@ import (
"time"
)

func initTransport(upstreamCAFile string) (http.RoundTripper, error) {
func initTransport(upstreamCAFile, upstreamClientCertPath, upstreamClientKeyPath string) (http.RoundTripper, error) {
if upstreamCAFile == "" {
return http.DefaultTransport, nil
}

rootPEM, err := os.ReadFile(upstreamCAFile)
upstreamCAPEM, err := os.ReadFile(upstreamCAFile)
if err != nil {
return nil, fmt.Errorf("error reading upstream CA file: %v", err)
return nil, fmt.Errorf("error reading upstream CA file: %w", err)
}

roots := x509.NewCertPool()
if ok := roots.AppendCertsFromPEM([]byte(rootPEM)); !ok {
var certKeyPair tls.Certificate
if len(upstreamClientCertPath) > 0 {
certKeyPair, err = tls.LoadX509KeyPair(upstreamClientCertPath, upstreamClientKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read upstream client cert/key: %w", err)
}
}

upstreamCAPool := x509.NewCertPool()
if ok := upstreamCAPool.AppendCertsFromPEM([]byte(upstreamCAPEM)); !ok {
return nil, errors.New("error parsing upstream CA certificate")
}

Expand All @@ -54,7 +62,13 @@ func initTransport(upstreamCAFile string) (http.RoundTripper, error) {
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: roots},
TLSClientConfig: &tls.Config{
RootCAs: upstreamCAPool,
},
}

if certKeyPair.Certificate != nil {
transport.TLSClientConfig.Certificates = []tls.Certificate{certKeyPair}
}

return transport, nil
Expand Down
161 changes: 158 additions & 3 deletions transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -16,12 +16,27 @@ limitations under the License.
package main

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
"net/http/httputil"
"path/filepath"
"testing"
"time"

certutil "k8s.io/client-go/util/cert"
"k8s.io/client-go/util/keyutil"
)

func TestInitTransportWithDefault(t *testing.T) {
roundTripper, err := initTransport("")
roundTripper, err := initTransport("", "", "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -32,7 +47,7 @@ func TestInitTransportWithDefault(t *testing.T) {
}

func TestInitTransportWithCustomCA(t *testing.T) {
roundTripper, err := initTransport("test/ca.pem")
roundTripper, err := initTransport("test/ca.pem", "", "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -42,3 +57,143 @@ func TestInitTransportWithCustomCA(t *testing.T) {
t.Error("expected root CA to be set, got nil")
}
}

type testHTTPHandler struct {
t *testing.T
gotCorrectRequest chan (bool)
}

func (h *testHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if len(req.TLS.PeerCertificates) > 0 {
h.gotCorrectRequest <- true
} else {
reqDump, _ := httputil.DumpRequest(req, false)
h.t.Logf("got request without client certificates:\n%s", reqDump)
h.t.Logf("TLS config: %#v", req.TLS)

h.gotCorrectRequest <- false
}
}

func TestInitTransportWithClientCertAuth(t *testing.T) {
requestReceived := make(chan (bool), 1)
tlsServer := http.Server{
Handler: &testHTTPHandler{t, requestReceived},
}

cert, key, err := certutil.GenerateSelfSignedCertKey("127.0.0.1", nil, nil)
if err != nil {
t.Fatalf("failed to create a new serving cert: %v", err)
}

tlsCert, err := tls.X509KeyPair(cert, key)
if err != nil {
t.Fatalf("failed to load a new serving cert: %v", err)
}

clientCert, clientKey, clientCA := generateClientCert(t)

l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on secure address: %v", err)
}
defer l.Close()
tlsListener := tls.NewListener(l, &tls.Config{
Certificates: []tls.Certificate{tlsCert},
ClientCAs: clientCA,
ClientAuth: tls.RequireAndVerifyClientCert,
})
defer tlsListener.Close()

go func() {
if err := tlsServer.Serve(tlsListener); err != nil {
t.Logf("failed to run the test server: %v", err)
}
}()
defer tlsServer.Close()

tmpDir := t.TempDir()
serverCertPath := filepath.Join(tmpDir, "server.crt")
clientCertPath := filepath.Join(tmpDir, "client.crt")
clientKeyPath := filepath.Join(tmpDir, "client.key")

if err := certutil.WriteCert(serverCertPath, cert); err != nil {
t.Fatalf("failed to write server cert: %v", err)
}
if err := certutil.WriteCert(clientCertPath, clientCert); err != nil {
t.Fatalf("failed to write client cert: %v", err)
}
if err := keyutil.WriteKey(clientKeyPath, clientKey); err != nil {
t.Fatalf("failed to write client key: %v", err)
}

roundTripper, err := initTransport(serverCertPath, clientCertPath, clientKeyPath)
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
}

httpReq, err := http.NewRequest(http.MethodPost, fmt.Sprintf("https://127.0.0.1:%d", l.Addr().(*net.TCPAddr).Port), nil)
if err != nil {
t.Fatalf("failed to create an HTTP request: %v", err)
}

resp, err := roundTripper.RoundTrip(httpReq)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()

select {
case isCorrectRequest := <-requestReceived:
if !isCorrectRequest {
t.Errorf("the server did not receive the expected request")
}
case <-time.NewTimer(1 * time.Second).C:
t.Fatalf("the server did not receive any requests in the required time")
}

}

func generateClientCert(t *testing.T) ([]byte, []byte, *x509.CertPool) {
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}
ca, err := certutil.NewSelfSignedCACert(certutil.Config{CommonName: "testing-ca"}, privKey)
if err != nil {
t.Fatalf("failed to generate CA cert: %v", err)
}

privKeyClient, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}

certDER, err := x509.CreateCertificate(rand.Reader,
&x509.Certificate{
Subject: pkix.Name{CommonName: "testing-client"},
SerialNumber: big.NewInt(15233),
NotBefore: time.Now().Add(-5 * time.Second),
NotAfter: time.Now().Add(1 * time.Minute),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
},
ca, privKeyClient.Public(), privKey,
)
if err != nil {
t.Fatalf("failed to create a client cert: %v", err)
}

certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

caPool := x509.NewCertPool()
caPool.AddCert(ca)

privKeyPEM, err := keyutil.MarshalPrivateKeyToPEM(privKeyClient)
if err != nil {
t.Fatalf("failed to encode private key to pem: %v", err)
}

return certPEM, privKeyPEM, caPool
}

0 comments on commit 4ddb68c

Please sign in to comment.