Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Record tests for national cloud better #2445

Merged
merged 1 commit into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions v2/internal/config/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ const (
podNamespaceVar = "POD_NAMESPACE"
)

// These are hardcoded because the init function that initializes them in azcore isn't in /cloud it's in /arm which
// we don't import.

var DefaultEndpoint = "https://management.azure.com"
var DefaultAudience = "https://management.core.windows.net/"
var DefaultAADAuthorityHost = "https://login.microsoftonline.com/"

// Values stores configuration values that are set for the operator.
type Values struct {
// SubscriptionID is the Azure subscription the operator will use
Expand Down Expand Up @@ -96,9 +103,9 @@ func (v Values) Cloud() cloud.Configuration {

// Special handling if we've got all the defaults just return the official public cloud
// configuration
hasDefaultAzureAuthorityHost := v.AzureAuthorityHost == "" || v.AzureAuthorityHost == cloud.AzurePublic.ActiveDirectoryAuthorityHost
hasDefaultResourceManagerEndpoint := v.ResourceManagerEndpoint == "" || v.ResourceManagerEndpoint == cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint
hasDefaultResourceManagerAudience := v.ResourceManagerAudience == "" || v.ResourceManagerAudience == cloud.AzurePublic.Services[cloud.ResourceManager].Audience
hasDefaultAzureAuthorityHost := v.AzureAuthorityHost == "" || v.AzureAuthorityHost == DefaultAADAuthorityHost
hasDefaultResourceManagerEndpoint := v.ResourceManagerEndpoint == "" || v.ResourceManagerEndpoint == DefaultEndpoint
hasDefaultResourceManagerAudience := v.ResourceManagerAudience == "" || v.ResourceManagerAudience == DefaultAudience

if hasDefaultResourceManagerEndpoint && hasDefaultResourceManagerAudience && hasDefaultAzureAuthorityHost {
return cloud.AzurePublic
Expand All @@ -109,13 +116,13 @@ func (v Values) Cloud() cloud.Configuration {
resourceManagerEndpoint := v.ResourceManagerEndpoint
resourceManagerAudience := v.ResourceManagerAudience
if azureAuthorityHost == "" {
azureAuthorityHost = cloud.AzurePublic.ActiveDirectoryAuthorityHost
azureAuthorityHost = DefaultAADAuthorityHost
}
if resourceManagerAudience == "" {
resourceManagerAudience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
resourceManagerAudience = DefaultAudience
}
if resourceManagerEndpoint == "" {
resourceManagerEndpoint = cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint
resourceManagerEndpoint = DefaultEndpoint
}

return cloud.Configuration{
Expand Down Expand Up @@ -150,9 +157,9 @@ func ReadFromEnvironment() (Values, error) {
result.PodNamespace = os.Getenv(podNamespaceVar)
result.TargetNamespaces = parseTargetNamespaces(os.Getenv(targetNamespacesVar))
result.SyncPeriod, err = parseSyncPeriod()
result.ResourceManagerEndpoint = envOrDefault(resourceManagerEndpointVar, cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint)
result.ResourceManagerAudience = envOrDefault(resourceManagerAudienceVar, cloud.AzurePublic.Services[cloud.ResourceManager].Audience)
result.AzureAuthorityHost = envOrDefault(azureAuthorityHostVar, cloud.AzurePublic.ActiveDirectoryAuthorityHost)
result.ResourceManagerEndpoint = envOrDefault(resourceManagerEndpointVar, DefaultEndpoint)
result.ResourceManagerAudience = envOrDefault(resourceManagerAudienceVar, DefaultAudience)
result.AzureAuthorityHost = envOrDefault(azureAuthorityHostVar, DefaultAADAuthorityHost)

if err != nil {
return result, errors.Wrapf(err, "parsing %q", syncPeriodVar)
Expand Down
2 changes: 2 additions & 0 deletions v2/internal/config/vars_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"reflect"
"testing"

// Importing this for side effects is required as it initializes cloud.AzurePublic
_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
. "github.com/onsi/gomega"

Expand Down
22 changes: 15 additions & 7 deletions v2/internal/testcommon/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ import (
// this is shared between tests because
// instantiating it requires HTTP calls
var cachedCreds azcore.TokenCredential
var cachedSubID AzureIDs
var cachedIds AzureIDs

const TestBillingIDVar = "TEST_BILLING_ID"

type AzureIDs struct {
subscriptionID string
tenantID string
subscriptionID string
tenantID string
billingInvoiceID string
}

func getCreds() (azcore.TokenCredential, AzureIDs, error) {

if cachedCreds != nil {
return cachedCreds, cachedSubID, nil
return cachedCreds, cachedIds, nil
}

creds, err := azidentity.NewDefaultAzureCredential(nil)
Expand All @@ -46,12 +49,17 @@ func getCreds() (azcore.TokenCredential, AzureIDs, error) {
return nil, AzureIDs{}, errors.Errorf("required environment variable %q was not supplied", config.TenantIDVar)
}

// This is test specific and doesn't have a corresponding config entry. It's also optional as it's only required for
// a small number of tests. Those tests will check for it explicitly
billingInvoiceId := os.Getenv(TestBillingIDVar)

ids := AzureIDs{
subscriptionID: subscriptionID,
tenantID: tenantID,
subscriptionID: subscriptionID,
tenantID: tenantID,
billingInvoiceID: billingInvoiceId,
}

cachedCreds = creds
cachedSubID = ids
cachedIds = ids
return creds, ids, nil
}
138 changes: 93 additions & 45 deletions v2/internal/testcommon/test_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,19 @@ type TestContext struct {

type PerTestContext struct {
TestContext
T *testing.T
logger logr.Logger
AzureClientRecorder *recorder.Recorder
AzureClient *genericarmclient.GenericClient
AzureSubscription string
AzureTenant string
AzureMatch *ARMMatcher
Namer ResourceNamer
NoSpaceNamer ResourceNamer
TestName string
Namespace string
Ctx context.Context
T *testing.T
logger logr.Logger
AzureClientRecorder *recorder.Recorder
AzureClient *genericarmclient.GenericClient
AzureSubscription string
AzureTenant string
AzureBillingInvoiceID string
AzureMatch *ARMMatcher
Namer ResourceNamer
NoSpaceNamer ResourceNamer
TestName string
Namespace string
Ctx context.Context
// CountsTowardsParallelLimits true means that the envtest (if any) started for this test pass counts towards the limit of
// concurrent envtests running at once. If this is false, it doesn't count towards the limit.
CountsTowardsParallelLimits bool
Expand All @@ -78,6 +79,8 @@ const ResourcePrefix = "asotest"
// either a real cluster or a kind cluster.
const LiveResourcePrefix = "asolivetest"

const DummyBillingId = "/providers/Microsoft.Billing/billingAccounts/00000000-0000-0000-0000-000000000000:00000000-0000-0000-0000-000000000000_2019-05-31/billingProfiles/0000-0000-000-000/invoiceSections/0000-0000-000-000"

func NewTestContext(
region string,
recordReplay bool,
Expand All @@ -93,28 +96,31 @@ func (tc TestContext) ForTest(t *testing.T, cfg config.Values) (PerTestContext,
logger := NewTestLogger(t)

cassetteName := "recordings/" + t.Name()
creds, azureIDs, recorder, err := createRecorder(cassetteName, tc.RecordReplay)
details, err := createRecorder(cassetteName, cfg, tc.RecordReplay)
if err != nil {
return PerTestContext{}, errors.Wrapf(err, "creating recorder")
}
// Use the recorder-specific CFG, which will force URLs and AADAuthorityHost (among other things) to default
// values so that the recordings look the same regardless of which cloud you ran them in
cfg = details.cfg

// To Go SDK client reuses HTTP clients among instances by default. We add handlers to the HTTP client based on
// the specific test in question, which means that clients cannot be reused.
// We explicitly create a new http.Client so that the recording from one test doesn't
// get used for all other parallel tests.
httpClient := &http.Client{
Transport: addCountHeader(translateErrors(recorder, cassetteName, t)),
Transport: addCountHeader(translateErrors(details.recorder, cassetteName, t)),
}

var armClient *genericarmclient.GenericClient
armClient, err = genericarmclient.NewGenericClientFromHTTPClient(cfg.Cloud(), creds, httpClient, azureIDs.subscriptionID, metrics.NewARMClientMetrics())
armClient, err = genericarmclient.NewGenericClientFromHTTPClient(cfg.Cloud(), details.creds, httpClient, details.ids.subscriptionID, metrics.NewARMClientMetrics())
if err != nil {
return PerTestContext{}, errors.Wrapf(err, "failed to create generic ARM client")
}

t.Cleanup(func() {
if !t.Failed() {
err := recorder.Stop()
err := details.recorder.Stop()
if err != nil {
// cleanup function should not error-out
logger.Error(err, "unable to stop ARM client recorder")
Expand All @@ -141,19 +147,20 @@ func (tc TestContext) ForTest(t *testing.T, cfg config.Values) (PerTestContext,
context := context.Background() // we could consider using context.WithTimeout(OperationTimeout()) here

return PerTestContext{
TestContext: tc,
T: t,
logger: logger,
Namer: namer,
NoSpaceNamer: namer.WithSeparator(""),
AzureClient: armClient,
AzureSubscription: azureIDs.subscriptionID,
AzureTenant: azureIDs.tenantID,
AzureMatch: NewARMMatcher(armClient),
AzureClientRecorder: recorder,
TestName: t.Name(),
Namespace: createTestNamespaceName(t),
Ctx: context,
TestContext: tc,
T: t,
logger: logger,
Namer: namer,
NoSpaceNamer: namer.WithSeparator(""),
AzureClient: armClient,
AzureSubscription: details.ids.subscriptionID,
AzureTenant: details.ids.tenantID,
AzureBillingInvoiceID: details.ids.billingInvoiceID,
AzureMatch: NewARMMatcher(armClient),
AzureClientRecorder: details.recorder,
TestName: t.Name(),
Namespace: createTestNamespaceName(t),
Ctx: context,
}, nil
}

Expand Down Expand Up @@ -194,7 +201,14 @@ func ensureCassetteFileExists(cassetteName string) error {
return nil
}

func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredential, AzureIDs, *recorder.Recorder, error) {
type recorderDetails struct {
creds azcore.TokenCredential
ids AzureIDs
recorder *recorder.Recorder
cfg config.Values
}

func createRecorder(cassetteName string, cfg config.Values, recordReplay bool) (recorderDetails, error) {
var err error
var r *recorder.Recorder
if recordReplay {
Expand All @@ -204,7 +218,7 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
}

if err != nil {
return nil, AzureIDs{}, nil, errors.Wrapf(err, "creating recorder")
return recorderDetails{}, errors.Wrapf(err, "creating recorder")
}

var creds azcore.TokenCredential
Expand All @@ -214,14 +228,19 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
// if we are recording, we need auth
creds, azureIDs, err = getCreds()
if err != nil {
return nil, azureIDs, nil, err
return recorderDetails{}, err
}
} else {
// if we are replaying, we won't need auth
// and we use a dummy subscription ID/tenant ID
creds = mockTokenCred{}
azureIDs.tenantID = uuid.Nil.String()
azureIDs.subscriptionID = uuid.Nil.String()
azureIDs.billingInvoiceID = DummyBillingId
// Force these values to be the default
cfg.ResourceManagerEndpoint = config.DefaultEndpoint
cfg.ResourceManagerAudience = config.DefaultAudience
cfg.AzureAuthorityHost = config.DefaultAADAuthorityHost
}

// check body as well as URL/Method (copied from go-vcr documentation)
Expand Down Expand Up @@ -252,29 +271,42 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
// rewrite all request/response fields to hide the real subscription ID
// this is *not* a security measure but intended to make the tests updateable from
// any subscription, so a contributor can update the tests against their own sub.
hideID := func(s string, id string) string {
return strings.ReplaceAll(s, id, uuid.Nil.String())
hide := func(s string, id string, replacement string) string {
return strings.ReplaceAll(s, id, replacement)
}

i.Request.Body = hideRecordingData(hideID(i.Request.Body, azureIDs.subscriptionID))
i.Response.Body = hideRecordingData(hideID(i.Response.Body, azureIDs.subscriptionID))
i.Request.URL = hideID(i.Request.URL, azureIDs.subscriptionID)
// Note that this changes the cassette in-place so there's no return needed
hideCassetteString := func(cas *cassette.Interaction, id string, replacement string) {
i.Request.Body = strings.ReplaceAll(cas.Request.Body, id, replacement)
i.Response.Body = strings.ReplaceAll(cas.Response.Body, id, replacement)
i.Request.URL = strings.ReplaceAll(cas.Request.URL, id, replacement)
}

// Hide the subscription ID
hideCassetteString(i, azureIDs.subscriptionID, uuid.Nil.String())
// Hide the tenant ID
hideCassetteString(i, azureIDs.tenantID, uuid.Nil.String())
// Hide the billing ID
hideCassetteString(i, azureIDs.billingInvoiceID, DummyBillingId)

i.Request.Body = hideRecordingData(hideID(i.Request.Body, azureIDs.tenantID))
i.Response.Body = hideRecordingData(hideID(i.Response.Body, azureIDs.tenantID))
i.Request.URL = hideID(i.Request.URL, azureIDs.tenantID)
// Hiding other sensitive fields
i.Request.Body = hideRecordingData(i.Request.Body)
i.Response.Body = hideRecordingData(i.Response.Body)
i.Request.URL = hideURLData(i.Request.URL)

for _, values := range i.Request.Headers {
for i := range values {
values[i] = hideID(values[i], azureIDs.subscriptionID)
values[i] = hideID(values[i], azureIDs.tenantID)
values[i] = hide(values[i], azureIDs.subscriptionID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.tenantID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.billingInvoiceID, DummyBillingId)
}
}

for _, values := range i.Response.Headers {
for i := range values {
values[i] = hideID(values[i], azureIDs.subscriptionID)
values[i] = hideID(values[i], azureIDs.tenantID)
values[i] = hide(values[i], azureIDs.subscriptionID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.tenantID, uuid.Nil.String())
values[i] = hide(values[i], azureIDs.billingInvoiceID, DummyBillingId)
}
}

Expand All @@ -289,7 +321,12 @@ func createRecorder(cassetteName string, recordReplay bool) (azcore.TokenCredent
return nil
})

return creds, azureIDs, r, nil
return recorderDetails{
creds: creds,
ids: azureIDs,
recorder: r,
cfg: cfg,
}, nil
}

var requestHeadersToRemove = []string{
Expand Down Expand Up @@ -333,6 +370,9 @@ var (

// kubeConfigMatcher specifically matches base64 data returned by the AKS get keys API
kubeConfigMatcher = regexp.MustCompile(`"value": "[a-zA-Z0-9+/]+={0,2}"`)

// baseURLMatcher matches the base part of a URL
baseURLMatcher = regexp.MustCompile(`^https://[^/]+/`)
)

// hideDates replaces all ISO8601 datetimes with a fixed value
Expand All @@ -359,6 +399,10 @@ func hideKubeConfigs(s string) string {
return kubeConfigMatcher.ReplaceAllLiteralString(s, `"value": "IA=="`) // Have to replace with valid base64 data, so replace with " "
}

func hideBaseRequestURL(s string) string {
return baseURLMatcher.ReplaceAllLiteralString(s, `https://management.azure.com/`)
}

func hideRecordingData(s string) string {
result := hideDates(s)
result = hideSSHKeys(result)
Expand All @@ -369,6 +413,10 @@ func hideRecordingData(s string) string {
return result
}

func hideURLData(s string) string {
return hideBaseRequestURL(s)
}

func (tc PerTestContext) NewTestResourceGroup() *resources.ResourceGroup {
return &resources.ResourceGroup{
ObjectMeta: metav1.ObjectMeta{
Expand Down