diff --git a/provider/pkg/azure/client_azcore.go b/provider/pkg/azure/client_azcore.go index e16a37b42611..8fbe1a9e70b0 100644 --- a/provider/pkg/azure/client_azcore.go +++ b/provider/pkg/azure/client_azcore.go @@ -37,20 +37,15 @@ type azCoreClient struct { updatePollingIntervalSeconds int64 } -func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, azureCloud cloud.Configuration, opts *arm.ClientOptions, -) (AzureClient, error) { - // Hook our logging up to the azcore logger. - log.SetListener(func(event log.Event, msg string) { - // Retry logging is very verbose and the number of the retry attempt is already contained - // in the response event. - if event != log.EventRetryPolicy { - logging.V(9).Infof("[azcore] %v: %s", event, msg) - } - }) - +func initPipelineOpts(azureCloud cloud.Configuration, opts *arm.ClientOptions) *arm.ClientOptions { if opts == nil { - opts = &arm.ClientOptions{} + opts = &arm.ClientOptions{ + ClientOptions: policy.ClientOptions{ + Cloud: azureCloud, + }, + } } + // azcore logging will only happen at log level 9. opts.Logging.IncludeBody = true @@ -72,7 +67,6 @@ func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, a http.StatusServiceUnavailable, // 503 http.StatusGatewayTimeout, // 504 } - opts.Retry.ShouldRetry = func(resp *http.Response, err error) bool { if err != nil { return true @@ -81,14 +75,29 @@ func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, a if runtime.HasStatusCode(resp, retryableStatusCodes...) { return true } - if shouldRetryConflict(resp) { return true } - return false } + return opts +} + +// NewAzCoreClient creates a new AzureClient using the azcore SDK. For general use, leave userOpts +// nil to use the default options. If you do set it, make sure to set its ClientOptions.Cloud field. +func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, azureCloud cloud.Configuration, userOpts *arm.ClientOptions, +) (AzureClient, error) { + // Hook our logging up to the azcore logger. + log.SetListener(func(event log.Event, msg string) { + // Retry logging is very verbose and the number of the retry attempt is already contained + // in the response event. + if event != log.EventRetryPolicy { + logging.V(9).Infof("[azcore] %v: %s", event, msg) + } + }) + + opts := initPipelineOpts(azureCloud, userOpts) pipeline, err := armruntime.NewPipeline("pulumi-azure-native", version.Version, tokenCredential, runtime.PipelineOptions{}, opts) if err != nil { diff --git a/provider/pkg/azure/client_azcore_test.go b/provider/pkg/azure/client_azcore_test.go index 57feb2b86890..0c06535ec6e7 100644 --- a/provider/pkg/azure/client_azcore_test.go +++ b/provider/pkg/azure/client_azcore_test.go @@ -10,6 +10,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" @@ -20,6 +21,79 @@ import ( "github.com/stretchr/testify/require" ) +func TestInitPipelineOpts(t *testing.T) { + t.Run("retry delays", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.InDelta(t, 20*time.Second, opts.Retry.RetryDelay, 10.0) + assert.InDelta(t, 120*time.Second, opts.Retry.MaxRetryDelay, 30.0) + }) + + t.Run("cloud is public", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.Equal(t, cloud.AzurePublic, opts.ClientOptions.Cloud) + }) + + t.Run("cloud is usgov", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzureGovernment, nil) + assert.Equal(t, cloud.AzureGovernment, opts.ClientOptions.Cloud) + }) + + t.Run("cloud is china", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzureChina, nil) + assert.Equal(t, cloud.AzureChina, opts.ClientOptions.Cloud) + }) + + t.Run("should retry", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.NotNil(t, opts.Retry.ShouldRetry) + }) + + t.Run("retries on 408 timeout", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusRequestTimeout}, nil)) + }) + + t.Run("retries on 409 conflict when another operation is in progress", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + header := http.Header{} + header.Add("x-ms-error-code", "AnotherOperationInProgress") + assert.True(t, opts.Retry.ShouldRetry(&http.Response{ + StatusCode: http.StatusConflict, + Header: header, + }, nil)) + }) + + t.Run("doesn't retry on 409 conflict when no other operation is in progress", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.False(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusConflict}, nil)) + }) + + t.Run("retries on 429 too many requests", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusTooManyRequests}, nil)) + }) + + t.Run("retries on 500 internal server error", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusInternalServerError}, nil)) + }) + + t.Run("retries on 502 bad gateway", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusBadGateway}, nil)) + }) + + t.Run("retries on 503 service unavailable", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil)) + }) + + t.Run("retries on 504 gateway timeout", func(t *testing.T) { + opts := initPipelineOpts(cloud.AzurePublic, nil) + assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusGatewayTimeout}, nil)) + }) +} + func TestNormalizeLocationHeader(t *testing.T) { const host = "https://management.azure.com" const apiVersion = "2022-09-01" diff --git a/provider/pkg/provider/provider_test.go b/provider/pkg/provider/provider_test.go index 86c6c20a9a7d..e06cab9614ee 100644 --- a/provider/pkg/provider/provider_test.go +++ b/provider/pkg/provider/provider_test.go @@ -13,6 +13,7 @@ import ( "github.com/pulumi/pulumi-azure-native/v2/provider/pkg/provider/crud" "github.com/pulumi/pulumi-azure-native/v2/provider/pkg/resources" "github.com/pulumi/pulumi-azure-native/v2/provider/pkg/resources/customresources" + "github.com/pulumi/pulumi-azure-native/v2/provider/pkg/util" "github.com/pulumi/pulumi/sdk/v3/go/common/resource" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -424,6 +425,10 @@ func TestUsesCorrectAzureClient(t *testing.T) { } func TestAzcoreAzureClientUsesCorrectCloud(t *testing.T) { + if !util.EnableAzcoreBackend() { + t.Skip() + } + for expectedHost, cloudInstance := range map[string]cloud.Configuration{ "https://management.azure.com": cloud.AzurePublic, "https://management.chinacloudapi.cn": cloud.AzureChina, @@ -440,7 +445,7 @@ func TestAzcoreAzureClientUsesCorrectCloud(t *testing.T) { // Use reflection to get the value of the private 'host' field clientValue := reflect.ValueOf(client).Elem() hostField := clientValue.FieldByName("host") - require.True(t, hostField.IsValid(), "host field should be valid", expectedHost) + require.True(t, hostField.IsValid(), "host field should be valid (%s)", expectedHost) assert.Equal(t, expectedHost, hostField.String()) }