Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): auth library can talk to S2A over mTLS #10634

Merged
merged 3 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 70 additions & 12 deletions auth/internal/transport/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package transport
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"log"
"net"
"net/http"
"net/url"
Expand All @@ -44,10 +46,12 @@ const (
googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"

universeDomainPlaceholder = "UNIVERSE_DOMAIN"

mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
mtlsMDSKey = "/run/google-mds-mtls/client.key"
)

var (
mdsMTLSAutoConfigSource mtlsConfigSource
errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
)

Expand Down Expand Up @@ -120,7 +124,20 @@ func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCrede
defaultTransportCreds := credentials.NewTLS(&tls.Config{
GetClientCertificate: config.clientCertSource,
})
if config.s2aAddress == "" {

var s2aAddr string
var transportCredsForS2A credentials.TransportCredentials

if config.mtlsS2AAddress != "" {
s2aAddr = config.mtlsS2AAddress
transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
if err != nil {
log.Printf("Loading MTLS MDS credentials failed: %v", err)
return defaultTransportCreds, config.endpoint, nil
}
} else if config.s2aAddress != "" {
s2aAddr = config.s2aAddress
} else {
return defaultTransportCreds, config.endpoint, nil
}

Expand All @@ -133,8 +150,9 @@ func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCrede
}

s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
S2AAddress: config.s2aAddress,
FallbackOpts: fallbackOpts,
S2AAddress: s2aAddr,
TransportCreds: transportCredsForS2A,
FallbackOpts: fallbackOpts,
})
if err != nil {
// Use default if we cannot initialize S2A client transport credentials.
Expand All @@ -151,7 +169,19 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context,
return nil, nil, err
}

if config.s2aAddress == "" {
var s2aAddr string
var transportCredsForS2A credentials.TransportCredentials

if config.mtlsS2AAddress != "" {
s2aAddr = config.mtlsS2AAddress
transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
if err != nil {
log.Printf("Loading MTLS MDS credentials failed: %v", err)
return config.clientCertSource, nil, nil
}
} else if config.s2aAddress != "" {
s2aAddr = config.s2aAddress
} else {
return config.clientCertSource, nil, nil
}

Expand All @@ -169,12 +199,38 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context,
}

dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
S2AAddress: config.s2aAddress,
FallbackOpts: fallbackOpts,
S2AAddress: s2aAddr,
TransportCreds: transportCredsForS2A,
FallbackOpts: fallbackOpts,
})
return nil, dialTLSContextFunc, nil
}

func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
rootPEM, err := os.ReadFile(mtlsMDSRootFile)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(rootPEM)
if !ok {
return nil, errors.New("failed to load MTLS MDS root certificate")
}
// The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
// followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
// tls.X509KeyPair function as both the certificate chain and private key arguments.
cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
if err != nil {
return nil, err
}
tlsConfig := tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}
return credentials.NewTLS(&tlsConfig), nil
}

func getTransportConfig(opts *Options) (*transportConfig, error) {
clientCertSource, err := GetClientCertificateProvider(opts)
if err != nil {
Expand All @@ -196,17 +252,17 @@ func getTransportConfig(opts *Options) (*transportConfig, error) {
return nil, errUniverseNotSupportedMTLS
}

s2aMTLSEndpoint := opts.DefaultMTLSEndpoint

s2aAddress := GetS2AAddress()
if s2aAddress == "" {
mtlsS2AAddress := GetMTLSS2AAddress()
if s2aAddress == "" && mtlsS2AAddress == "" {
return &defaultTransportConfig, nil
}
return &transportConfig{
clientCertSource: clientCertSource,
endpoint: endpoint,
s2aAddress: s2aAddress,
s2aMTLSEndpoint: s2aMTLSEndpoint,
mtlsS2AAddress: mtlsS2AAddress,
s2aMTLSEndpoint: opts.DefaultMTLSEndpoint,
}, nil
}

Expand Down Expand Up @@ -241,8 +297,10 @@ type transportConfig struct {
clientCertSource cert.Provider
// The corresponding endpoint to use based on client certificate source.
endpoint string
// The S2A address if it can be used, otherwise an empty string.
// The plaintext S2A address if it can be used, otherwise an empty string.
s2aAddress string
// The MTLS S2A address if it can be used, otherwise an empty string.
mtlsS2AAddress string
// The MTLS endpoint to use with S2A.
s2aMTLSEndpoint string
}
Expand Down
108 changes: 88 additions & 20 deletions auth/internal/transport/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"fmt"
"net/http"
"testing"
"time"

"cloud.google.com/go/auth/internal"
"cloud.google.com/go/auth/internal/transport/cert"
Expand Down Expand Up @@ -50,6 +49,20 @@ var (
return string(configStr), nil
}

validConfigRespMTLSS2A = func() (string, error) {
validConfig := mtlsConfig{
S2A: &s2aAddresses{
PlaintextAddress: "",
MTLSAddress: testMTLSS2AAddr,
},
}
configStr, err := json.Marshal(validConfig)
if err != nil {
return "", err
}
return string(configStr), nil
}

errorConfigResp = func() (string, error) {
return "", fmt.Errorf("error getting config")
}
Expand Down Expand Up @@ -250,7 +263,7 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
}
}

func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
func TestGetGRPCTransportConfigAndEndpoint_S2A(t *testing.T) {
testCases := []struct {
name string
opts *Options
Expand Down Expand Up @@ -324,11 +337,21 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
validConfigResp,
testRegularEndpoint,
},
{
"no client cert, MTLS S2A address not empty, no MTLS MDS cert",
&Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
validConfigRespMTLSS2A,
testRegularEndpoint,
},
}
defer setupTest(t)()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
httpGetMetadataMTLSConfig = tc.s2ARespFn
mtlsConfiguration, _ = queryConfig()
if tc.opts.ClientCertProvider != nil {
t.Setenv(googleAPIUseCertSource, "true")
} else {
Expand All @@ -338,17 +361,15 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
if tc.want != endpoint {
t.Fatalf("want endpoint: %s, got %s", tc.want, endpoint)
}
// Let the cached MTLS config expire at the end of each test case.
time.Sleep(2 * time.Millisecond)
})
}
}

func TestGetHTTPTransportConfig_S2a(t *testing.T) {
func TestGetHTTPTransportConfig_S2A(t *testing.T) {
testCases := []struct {
name string
opts *Options
s2aFn func() (string, error)
s2ARespFn func() (string, error)
want string
isDialFnNil bool
}{
Expand All @@ -359,7 +380,7 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
ClientCertProvider: fakeClientCertSource,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testMTLSEndpoint,
isDialFnNil: true,
},
Expand All @@ -369,16 +390,16 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
want: testMTLSEndpoint,
s2ARespFn: validConfigResp,
want: testMTLSEndpoint,
},
{
name: "no client cert, S2A address empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: invalidConfigResp,
s2ARespFn: invalidConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
Expand All @@ -389,7 +410,7 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
Endpoint: testOverrideEndpoint,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testOverrideEndpoint,
isDialFnNil: true,
},
Expand All @@ -399,7 +420,7 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultMTLSEndpoint: "",
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
Expand All @@ -410,15 +431,26 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
Client: http.DefaultClient,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, MTLS S2A address not empty, no MTLS MDS cert",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2ARespFn: validConfigRespMTLSS2A,
want: testRegularEndpoint,
isDialFnNil: true,
},
}
defer setupTest(t)()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
httpGetMetadataMTLSConfig = tc.s2aFn
httpGetMetadataMTLSConfig = tc.s2ARespFn
mtlsConfiguration, _ = queryConfig()
if tc.opts.ClientCertProvider != nil {
t.Setenv(googleAPIUseCertSource, "true")
} else {
Expand All @@ -431,22 +463,58 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
if want, got := tc.isDialFnNil, dialFunc == nil; want != got {
t.Errorf("expecting returned dialFunc is nil: [%v], got [%v]", tc.isDialFnNil, got)
}
// Let MTLS config expire at end of each test case.
time.Sleep(2 * time.Millisecond)
})
}
}

func TestLoadMTLSMDSTransportCreds(t *testing.T) {
testCases := []struct {
name string
rootFile string
keyFile string
wantErr bool
}{
{
name: "missing root file",
rootFile: "",
keyFile: "./testdata/mtls_mds_key.pem",
wantErr: true,
},
{
name: "missing key file",
rootFile: "./testdata/mtls_mds_root.pem",
keyFile: "",
wantErr: true,
},
{
name: "missing both root and key files",
rootFile: "",
keyFile: "",
wantErr: true,
},
{
name: "load credentials success",
rootFile: "./testdata/mtls_mds_root.pem",
keyFile: "./testdata/mtls_mds_key.pem",
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := loadMTLSMDSTransportCreds(tc.rootFile, tc.keyFile)
if gotErr := err != nil; gotErr != tc.wantErr {
t.Errorf("loadMTLSMDSTransportCreds(%q, %q) got error: %v, want error: %v", tc.rootFile, tc.keyFile, gotErr, tc.wantErr)
}
})
}
}

func setupTest(t *testing.T) func() {
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry

configExpiry = time.Millisecond
t.Setenv(googleAPIUseS2AEnv, "true")

return func() {
httpGetMetadataMTLSConfig = oldHTTPGet
configExpiry = oldExpiry
}
}

Expand Down
Loading
Loading