Skip to content

Commit

Permalink
azcore fix: initialize request pipeline with configured cloud (#3802)
Browse files Browse the repository at this point in the history
Fixes #3795, hopefully.

The way Azure's azcore and azidentity SDKs are designed makes it hard to
test this, but reading [the source
here](https://github.com/Azure/azure-sdk-for-go/blob/sdk/azcore/v1.16.0/sdk/azcore/arm/runtime/pipeline.go#L60)
it looks like this change should fix the issue.
  • Loading branch information
thomas11 authored Dec 18, 2024
1 parent c57661c commit 1285abd
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 16 deletions.
39 changes: 24 additions & 15 deletions provider/pkg/azure/client_azcore.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,15 @@ type azCoreClient struct {
updatePollingIntervalSeconds int64
}

func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, azureCloud cloud.Configuration, opts *arm.ClientOptions,
) (AzureClient, error) {
// Hook our logging up to the azcore logger.
log.SetListener(func(event log.Event, msg string) {
// Retry logging is very verbose and the number of the retry attempt is already contained
// in the response event.
if event != log.EventRetryPolicy {
logging.V(9).Infof("[azcore] %v: %s", event, msg)
}
})

func initPipelineOpts(azureCloud cloud.Configuration, opts *arm.ClientOptions) *arm.ClientOptions {
if opts == nil {
opts = &arm.ClientOptions{}
opts = &arm.ClientOptions{
ClientOptions: policy.ClientOptions{
Cloud: azureCloud,
},
}
}

// azcore logging will only happen at log level 9.
opts.Logging.IncludeBody = true

Expand All @@ -72,7 +67,6 @@ func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, a
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout, // 504
}

opts.Retry.ShouldRetry = func(resp *http.Response, err error) bool {
if err != nil {
return true
Expand All @@ -81,14 +75,29 @@ func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, a
if runtime.HasStatusCode(resp, retryableStatusCodes...) {
return true
}

if shouldRetryConflict(resp) {
return true
}

return false
}

return opts
}

// NewAzCoreClient creates a new AzureClient using the azcore SDK. For general use, leave userOpts
// nil to use the default options. If you do set it, make sure to set its ClientOptions.Cloud field.
func NewAzCoreClient(tokenCredential azcore.TokenCredential, userAgent string, azureCloud cloud.Configuration, userOpts *arm.ClientOptions,
) (AzureClient, error) {
// Hook our logging up to the azcore logger.
log.SetListener(func(event log.Event, msg string) {
// Retry logging is very verbose and the number of the retry attempt is already contained
// in the response event.
if event != log.EventRetryPolicy {
logging.V(9).Infof("[azcore] %v: %s", event, msg)
}
})

opts := initPipelineOpts(azureCloud, userOpts)
pipeline, err := armruntime.NewPipeline("pulumi-azure-native", version.Version, tokenCredential,
runtime.PipelineOptions{}, opts)
if err != nil {
Expand Down
74 changes: 74 additions & 0 deletions provider/pkg/azure/client_azcore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/url"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
Expand All @@ -20,6 +21,79 @@ import (
"github.com/stretchr/testify/require"
)

func TestInitPipelineOpts(t *testing.T) {
t.Run("retry delays", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.InDelta(t, 20*time.Second, opts.Retry.RetryDelay, 10.0)
assert.InDelta(t, 120*time.Second, opts.Retry.MaxRetryDelay, 30.0)
})

t.Run("cloud is public", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.Equal(t, cloud.AzurePublic, opts.ClientOptions.Cloud)
})

t.Run("cloud is usgov", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzureGovernment, nil)
assert.Equal(t, cloud.AzureGovernment, opts.ClientOptions.Cloud)
})

t.Run("cloud is china", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzureChina, nil)
assert.Equal(t, cloud.AzureChina, opts.ClientOptions.Cloud)
})

t.Run("should retry", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.NotNil(t, opts.Retry.ShouldRetry)
})

t.Run("retries on 408 timeout", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusRequestTimeout}, nil))
})

t.Run("retries on 409 conflict when another operation is in progress", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
header := http.Header{}
header.Add("x-ms-error-code", "AnotherOperationInProgress")
assert.True(t, opts.Retry.ShouldRetry(&http.Response{
StatusCode: http.StatusConflict,
Header: header,
}, nil))
})

t.Run("doesn't retry on 409 conflict when no other operation is in progress", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.False(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusConflict}, nil))
})

t.Run("retries on 429 too many requests", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusTooManyRequests}, nil))
})

t.Run("retries on 500 internal server error", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusInternalServerError}, nil))
})

t.Run("retries on 502 bad gateway", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusBadGateway}, nil))
})

t.Run("retries on 503 service unavailable", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusServiceUnavailable}, nil))
})

t.Run("retries on 504 gateway timeout", func(t *testing.T) {
opts := initPipelineOpts(cloud.AzurePublic, nil)
assert.True(t, opts.Retry.ShouldRetry(&http.Response{StatusCode: http.StatusGatewayTimeout}, nil))
})
}

func TestNormalizeLocationHeader(t *testing.T) {
const host = "https://management.azure.com"
const apiVersion = "2022-09-01"
Expand Down
7 changes: 6 additions & 1 deletion provider/pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/provider/crud"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/resources"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/resources/customresources"
"github.com/pulumi/pulumi-azure-native/v2/provider/pkg/util"
"github.com/pulumi/pulumi/sdk/v3/go/common/resource"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -424,6 +425,10 @@ func TestUsesCorrectAzureClient(t *testing.T) {
}

func TestAzcoreAzureClientUsesCorrectCloud(t *testing.T) {
if !util.EnableAzcoreBackend() {
t.Skip()
}

for expectedHost, cloudInstance := range map[string]cloud.Configuration{
"https://management.azure.com": cloud.AzurePublic,
"https://management.chinacloudapi.cn": cloud.AzureChina,
Expand All @@ -440,7 +445,7 @@ func TestAzcoreAzureClientUsesCorrectCloud(t *testing.T) {
// Use reflection to get the value of the private 'host' field
clientValue := reflect.ValueOf(client).Elem()
hostField := clientValue.FieldByName("host")
require.True(t, hostField.IsValid(), "host field should be valid", expectedHost)
require.True(t, hostField.IsValid(), "host field should be valid (%s)", expectedHost)

assert.Equal(t, expectedHost, hostField.String())
}
Expand Down

0 comments on commit 1285abd

Please sign in to comment.