Skip to content

Commit

Permalink
Record tests for national cloud better
Browse files Browse the repository at this point in the history
Also add support for billing invoice ID test variable
  • Loading branch information
matthchr committed Aug 12, 2022
1 parent 5ac613d commit d4b210a
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 61 deletions.
28 changes: 19 additions & 9 deletions v2/internal/config/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ const (
podNamespaceVar = "POD_NAMESPACE"
)

var DefaultEndpoint string
var DefaultAudience string
var DefaultAADAuthorityHost string

func init() {
DefaultEndpoint = cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint
DefaultAudience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
DefaultAADAuthorityHost = cloud.AzurePublic.ActiveDirectoryAuthorityHost
}

// Values stores configuration values that are set for the operator.
type Values struct {
// SubscriptionID is the Azure subscription the operator will use
Expand Down Expand Up @@ -96,9 +106,9 @@ func (v Values) Cloud() cloud.Configuration {

// Special handling if we've got all the defaults just return the official public cloud
// configuration
hasDefaultAzureAuthorityHost := v.AzureAuthorityHost == "" || v.AzureAuthorityHost == cloud.AzurePublic.ActiveDirectoryAuthorityHost
hasDefaultResourceManagerEndpoint := v.ResourceManagerEndpoint == "" || v.ResourceManagerEndpoint == cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint
hasDefaultResourceManagerAudience := v.ResourceManagerAudience == "" || v.ResourceManagerAudience == cloud.AzurePublic.Services[cloud.ResourceManager].Audience
hasDefaultAzureAuthorityHost := v.AzureAuthorityHost == "" || v.AzureAuthorityHost == DefaultAADAuthorityHost
hasDefaultResourceManagerEndpoint := v.ResourceManagerEndpoint == "" || v.ResourceManagerEndpoint == DefaultEndpoint
hasDefaultResourceManagerAudience := v.ResourceManagerAudience == "" || v.ResourceManagerAudience == DefaultAudience

if hasDefaultResourceManagerEndpoint && hasDefaultResourceManagerAudience && hasDefaultAzureAuthorityHost {
return cloud.AzurePublic
Expand All @@ -109,13 +119,13 @@ func (v Values) Cloud() cloud.Configuration {
resourceManagerEndpoint := v.ResourceManagerEndpoint
resourceManagerAudience := v.ResourceManagerAudience
if azureAuthorityHost == "" {
azureAuthorityHost = cloud.AzurePublic.ActiveDirectoryAuthorityHost
azureAuthorityHost = DefaultAADAuthorityHost
}
if resourceManagerAudience == "" {
resourceManagerAudience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
resourceManagerAudience = DefaultAudience
}
if resourceManagerEndpoint == "" {
resourceManagerEndpoint = cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint
resourceManagerEndpoint = DefaultEndpoint
}

return cloud.Configuration{
Expand Down Expand Up @@ -150,9 +160,9 @@ func ReadFromEnvironment() (Values, error) {
result.PodNamespace = os.Getenv(podNamespaceVar)
result.TargetNamespaces = parseTargetNamespaces(os.Getenv(targetNamespacesVar))
result.SyncPeriod, err = parseSyncPeriod()
result.ResourceManagerEndpoint = envOrDefault(resourceManagerEndpointVar, cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint)
result.ResourceManagerAudience = envOrDefault(resourceManagerAudienceVar, cloud.AzurePublic.Services[cloud.ResourceManager].Audience)
result.AzureAuthorityHost = envOrDefault(azureAuthorityHostVar, cloud.AzurePublic.ActiveDirectoryAuthorityHost)
result.ResourceManagerEndpoint = envOrDefault(resourceManagerEndpointVar, DefaultEndpoint)
result.ResourceManagerAudience = envOrDefault(resourceManagerAudienceVar, DefaultAudience)
result.AzureAuthorityHost = envOrDefault(azureAuthorityHostVar, DefaultAADAuthorityHost)

if err != nil {
return result, errors.Wrapf(err, "parsing %q", syncPeriodVar)
Expand Down
22 changes: 15 additions & 7 deletions v2/internal/testcommon/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ import (
// this is shared between tests because
// instantiating it requires HTTP calls
var cachedCreds azcore.TokenCredential
var cachedSubID AzureIDs
var cachedIds AzureIDs

const TestBillingIDVar = "TEST_BILLING_ID"

type AzureIDs struct {
subscriptionID string
tenantID string
subscriptionID string
tenantID string
billingInvoiceID string
}

func getCreds() (azcore.TokenCredential, AzureIDs, error) {

if cachedCreds != nil {
return cachedCreds, cachedSubID, nil
return cachedCreds, cachedIds, nil
}

creds, err := azidentity.NewDefaultAzureCredential(nil)
Expand All @@ -46,12 +49,17 @@ func getCreds() (azcore.TokenCredential, AzureIDs, error) {
return nil, AzureIDs{}, errors.Errorf("required environment variable %q was not supplied", config.TenantIDVar)
}

// This is test specific and doesn't have a corresponding config entry. It's also optional as it's only required for
// a small number of tests. Those tests will check for it explicitly
billingInvoiceId := os.Getenv(TestBillingIDVar)

ids := AzureIDs{
subscriptionID: subscriptionID,
tenantID: tenantID,
subscriptionID: subscriptionID,
tenantID: tenantID,
billingInvoiceID: billingInvoiceId,
}

cachedCreds = creds
cachedSubID = ids
cachedIds = ids
return creds, ids, nil
}
138 changes: 93 additions & 45 deletions v2/internal/testcommon/test_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ type TestContext struct {

type PerTestContext struct {
TestContext
T *testing.T
logger logr.Logger
AzureClientRecorder *recorder.Recorder
AzureClient *genericarmclient.GenericClient
AzureSubscription string
AzureTenant string
AzureMatch *ARMMatcher
Namer ResourceNamer
NoSpaceNamer ResourceNamer
TestName string
Namespace string
Ctx context.Context
T *testing.T
logger logr.Logger
AzureClientRecorder *recorder.Recorder
AzureClient *genericarmclient.GenericClient
AzureSubscription string
AzureTenant string
AzureBillingInvoiceID string
AzureMatch *ARMMatcher
Namer ResourceNamer
NoSpaceNamer ResourceNamer
TestName string
Namespace string
Ctx context.Context
// CountsTowardsParallelLimits true means that the envtest (if any) started for this test pass counts towards the limit of
// concurrent envtests running at once. If this is false, it doesn't count towards the limit.
CountsTowardsParallelLimits bool
Expand All @@ -78,6 +79,8 @@ const ResourcePrefix = "asotest"
// either a real cluster or a kind cluster.
const LiveResourcePrefix = "asolivetest"

const DummyBillingId = "/providers/Microsoft.Billing/billingAccounts/00000000-0000-0000-0000-000000000000:00000000-0000-0000-0000-000000000000_2019-05-31/billingProfiles/0000-0000-000-000/invoiceSections/0000-0000-000-000"

func NewTestContext(
region string,
recordReplay bool,
Expand All @@ -93,28 +96,31 @@ func (tc TestContext) ForTest(t *testing.T, cfg config.Values) (PerTestContext,
logger := NewTestLogger(t)

cassetteName := "recordings/" + t.Name()
creds, azureIDs, recorder, err := createRecorder(cassetteName, tc.RecordReplay)
details, err := createRecorder(cassetteName, cfg, tc.RecordReplay)
if err != nil {
return PerTestContext{}, errors.Wrapf(err, "creating recorder")
}
// Use the recorder-specific CFG, which will URLs and AADAuthorityHost (among other things) to default
// values so that the recordings look the same regardless of which cloud you ran them in
cfg = details.cfg

// To Go SDK client reuses HTTP clients among instances by default. We add handlers to the HTTP client based on
// the specific test in question, which means that clients cannot be reused.
// We explicitly create a new http.Client so that the recording from one test doesn't
// get used for all other parallel tests.
httpClient := &http.Client{
Transport: addCountHeader(translateErrors(recorder, cassetteName, t)),
Transport: addCountHeader(translateErrors(details.recorder, cassetteName, t)),
}

var armClient *genericarmclient.GenericClient
armClient, err = genericarmclient.NewGenericClientFromHTTPClient(cfg.Cloud(), creds, httpClient, azureIDs.subscriptionID, metrics.NewARMClientMetrics())
armClient, err = genericarmclient.NewGenericClientFromHTTPClient(cfg.Cloud(), details.creds, httpClient, details.ids.subscriptionID, metrics.NewARMClientMetrics())
if err != nil {
return PerTestContext{}, errors.Wrapf(err, "failed to create generic ARM client")
}

t.Cleanup(func() {
if !t.Failed() {
err := recorder.Stop()
err := details.recorder.Stop()
if err != nil {
// cleanup function should not error-out
logger.Error(err, "unable to stop ARM client recorder")
Expand All @@ -141,19 +147,20 @@ func (tc TestContext) ForTest(t *testing.T, cfg config.Values) (PerTestContext,
context := context.Background() // we could consider using context.WithTimeout(OperationTimeout()) here

return PerTestContext{
TestContext: tc,
T: t,
logger: logger,
Namer: namer,
NoSpaceNamer: namer.WithSeparator(""),
AzureClient: armClient,
AzureSubscription: azureIDs.subscriptionID,
AzureTenant: azureIDs.tenantID,
AzureMatch: NewARMMatcher(armClient),
AzureClientRecorder: recorder,
TestName: t.Name(),
Namespace: createTestNamespaceName(t),
Ctx: context,
TestContext: tc,
T: t,
logger: logger,
Namer: namer,
NoSpaceNamer: namer.WithSeparator(""),
AzureClient: armClient,
AzureSubscription: details.ids.subscriptionID,
AzureTenant: details.ids.tenantID,
AzureBillingInvoiceID: details.ids.billingInvoiceID,
AzureMatch: NewARMMatcher(armClient),
AzureClientRecorder: details.recorder,
TestName: t.Name(),
Namespace: createTestNamespaceName(t),
Ctx: context,
}, nil
}

Expand Down Expand Up @@ -194,7 +201,14 @@ func ensureCassetteFileExists(cassetteName string) error {
return nil
}

func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredential, AzureIDs, *recorder.Recorder, error) {
type recorderDetails struct {
creds azcore.TokenCredential
ids AzureIDs
recorder *recorder.Recorder
cfg config.Values
}

func createRecorder(cassetteName string, cfg config.Values, recordReplay bool) (recorderDetails, error) {
var err error
var r *recorder.Recorder
if recordReplay {
Expand All @@ -204,7 +218,7 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
}

if err != nil {
return nil, AzureIDs{}, nil, errors.Wrapf(err, "creating recorder")
return recorderDetails{}, errors.Wrapf(err, "creating recorder")
}

var creds azcore.TokenCredential
Expand All @@ -214,14 +228,19 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
// if we are recording, we need auth
creds, azureIDs, err = getCreds()
if err != nil {
return nil, azureIDs, nil, err
return recorderDetails{}, err
}
} else {
// if we are replaying, we won't need auth
// and we use a dummy subscription ID/tenant ID
creds = mockTokenCred{}
azureIDs.tenantID = uuid.Nil.String()
azureIDs.subscriptionID = uuid.Nil.String()
azureIDs.billingInvoiceID = DummyBillingId
// Force these values to be the default
cfg.ResourceManagerEndpoint = config.DefaultEndpoint
cfg.ResourceManagerAudience = config.DefaultAudience
cfg.AzureAuthorityHost = config.DefaultAADAuthorityHost
}

// check body as well as URL/Method (copied from go-vcr documentation)
Expand Down Expand Up @@ -252,29 +271,42 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
// rewrite all request/response fields to hide the real subscription ID
// this is *not* a security measure but intended to make the tests updateable from
// any subscription, so a contributor can update the tests against their own sub.
hideID := func(s string, id string) string {
return strings.ReplaceAll(s, id, uuid.Nil.String())
hide := func(s string, id string, replacement string) string {
return strings.ReplaceAll(s, id, replacement)
}

i.Request.Body = hideRecordingData(hideID(i.Request.Body, azureIDs.subscriptionID))
i.Response.Body = hideRecordingData(hideID(i.Response.Body, azureIDs.subscriptionID))
i.Request.URL = hideID(i.Request.URL, azureIDs.subscriptionID)
// Note that this changes the cassette in-place so there's no return needed
hideCassetteString := func(cas *cassette.Interaction, id string, replacement string) {
i.Request.Body = strings.ReplaceAll(i.Request.Body, id, replacement)
i.Response.Body = strings.ReplaceAll(i.Response.Body, id, replacement)
i.Request.URL = strings.ReplaceAll(i.Request.URL, id, replacement)
}

// Hide the subscription ID
hideCassetteString(i, azureIDs.subscriptionID, uuid.Nil.String())
// Hide the tenant ID
hideCassetteString(i, azureIDs.tenantID, uuid.Nil.String())
// Hide the billing ID
hideCassetteString(i, azureIDs.billingInvoiceID, DummyBillingId)

i.Request.Body = hideRecordingData(hideID(i.Request.Body, azureIDs.tenantID))
i.Response.Body = hideRecordingData(hideID(i.Response.Body, azureIDs.tenantID))
i.Request.URL = hideID(i.Request.URL, azureIDs.tenantID)
// Hiding other sensitive fields
i.Request.Body = hideRecordingData(i.Request.Body)
i.Response.Body = hideRecordingData(i.Response.Body)
i.Request.URL = hideURLData(i.Request.URL)

for _, values := range i.Request.Headers {
for i := range values {
values[i] = hideID(values[i], azureIDs.subscriptionID)
values[i] = hideID(values[i], azureIDs.tenantID)
values[i] = hide(values[i], azureIDs.subscriptionID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.tenantID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.billingInvoiceID, DummyBillingId)
}
}

for _, values := range i.Response.Headers {
for i := range values {
values[i] = hideID(values[i], azureIDs.subscriptionID)
values[i] = hideID(values[i], azureIDs.tenantID)
values[i] = hide(values[i], azureIDs.subscriptionID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.tenantID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.billingInvoiceID, DummyBillingId)
}
}

Expand All @@ -289,7 +321,12 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
return nil
})

return creds, azureIDs, r, nil
return recorderDetails{
creds: creds,
ids: azureIDs,
recorder: r,
cfg: cfg,
}, nil
}

var requestHeadersToRemove = []string{
Expand Down Expand Up @@ -333,6 +370,9 @@ var (

// kubeConfigMatcher specifically matches base64 data returned by the AKS get keys API
kubeConfigMatcher = regexp.MustCompile(`"value": "[a-zA-Z0-9+/]+={0,2}"`)

// baseURLMatcher matches the base part of a URL
baseURLMatcher = regexp.MustCompile(`^https://[^/]+/`)
)

// hideDates replaces all ISO8601 datetimes with a fixed value
Expand All @@ -359,6 +399,10 @@ func hideKubeConfigs(s string) string {
return kubeConfigMatcher.ReplaceAllLiteralString(s, `"value": "IA=="`) // Have to replace with valid base64 data, so replace with " "
}

func hideBaseRequestURL(s string) string {
return baseURLMatcher.ReplaceAllLiteralString(s, `https://management.azure.com/`)
}

func hideRecordingData(s string) string {
result := hideDates(s)
result = hideSSHKeys(result)
Expand All @@ -369,6 +413,10 @@ func hideRecordingData(s string) string {
return result
}

func hideURLData(s string) string {
return hideBaseRequestURL(s)
}

func (tc PerTestContext) NewTestResourceGroup() *resources.ResourceGroup {
return &resources.ResourceGroup{
ObjectMeta: metav1.ObjectMeta{
Expand Down

0 comments on commit d4b210a

Please sign in to comment.