Skip to content

Commit

Permalink
DefaultAzureCredential probes IMDS before sending it a token request (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Apr 5, 2024
1 parent 0aded76 commit 8f702b9
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 84 deletions.
1 change: 1 addition & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
* `ManagedIdentityCredential` now specifies resource IDs correctly for Azure Container Instances

### Other Changes
* Increased `DefaultAzureCredential` reliability when authenticating a managed identity on an Azure VM

## 1.6.0-beta.2 (2024-02-06)

Expand Down
2 changes: 1 addition & 1 deletion sdk/azidentity/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "go",
"TagPrefix": "go/azidentity",
"Tag": "go/azidentity_4d7934c64a"
"Tag": "go/azidentity_087379b475"
}
47 changes: 2 additions & 45 deletions sdk/azidentity/default_azure_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ package azidentity

import (
"context"
"errors"
"os"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand Down Expand Up @@ -98,13 +96,13 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameWorkloadIdentity, err: err})
}

o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions}
o := &ManagedIdentityCredentialOptions{ClientOptions: options.ClientOptions, dac: true}
if ID, ok := os.LookupEnv(azureClientID); ok {
o.ID = ClientID(ID)
}
miCred, err := NewManagedIdentityCredential(o)
if err == nil {
creds = append(creds, &timeoutWrapper{mic: miCred, timeout: time.Second})
creds = append(creds, miCred)
} else {
errorMessages = append(errorMessages, credNameManagedIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameManagedIdentity, err: err})
Expand Down Expand Up @@ -165,44 +163,3 @@ func (d *defaultCredentialErrorReporter) GetToken(ctx context.Context, opts poli
}

var _ azcore.TokenCredential = (*defaultCredentialErrorReporter)(nil)

// timeoutWrapper prevents a potentially very long timeout when managed identity isn't available
type timeoutWrapper struct {
mic *ManagedIdentityCredential
// timeout applies to all auth attempts until one doesn't time out
timeout time.Duration
}

// GetToken wraps DefaultAzureCredential's initial managed identity auth attempt with a short timeout
// because managed identity may not be available and connecting to IMDS can take several minutes to time out.
func (w *timeoutWrapper) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
var tk azcore.AccessToken
var err error
// no need to synchronize around this value because it's written only within ChainedTokenCredential's critical section
if w.timeout > 0 {
c, cancel := context.WithTimeout(ctx, w.timeout)
defer cancel()
tk, err = w.mic.GetToken(c, opts)
if isAuthFailedDueToContext(err) {
err = newCredentialUnavailableError(credNameManagedIdentity, "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information")
} else {
// some managed identity implementation is available, so don't apply the timeout to future calls
w.timeout = 0
}
} else {
tk, err = w.mic.GetToken(ctx, opts)
}
return tk, err
}

// unwraps nested AuthenticationFailedErrors to get the root error
func isAuthFailedDueToContext(err error) bool {
for {
var authFailedErr *AuthenticationFailedError
if !errors.As(err, &authFailedErr) {
break
}
err = authFailedErr.err
}
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
}
102 changes: 70 additions & 32 deletions sdk/azidentity/default_azure_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/stretchr/testify/require"
)

func TestDefaultAzureCredential_GetTokenSuccess(t *testing.T) {
Expand Down Expand Up @@ -191,8 +192,8 @@ func TestDefaultAzureCredential_UserAssignedIdentity(t *testing.T) {
t.Fatal(err)
}
for _, c := range cred.chain.sources {
if w, ok := c.(*timeoutWrapper); ok {
if actual := w.mic.mic.id; actual != ID {
if w, ok := c.(*ManagedIdentityCredential); ok {
if actual := w.mic.id; actual != ID {
t.Fatalf(`expected "%s", got "%v"`, ID, actual)
}
return
Expand Down Expand Up @@ -239,6 +240,36 @@ func TestDefaultAzureCredential_Workload(t *testing.T) {
testGetTokenSuccess(t, cred)
}

func TestDefaultAzureCredential_IMDSLive(t *testing.T) {
if recording.GetRecordMode() != recording.PlaybackMode && !liveManagedIdentity.imds {
t.Skip("set IDENTITY_IMDS_AVAILABLE to run this test")
}
// unsetting environment variables to skip EnvironmentCredential and other managed identity sources
for _, k := range []string{azureTenantID, identityEndpoint, msiEndpoint} {
if v, set := os.LookupEnv(k); set {
require.NoError(t, os.Unsetenv(k))
defer os.Setenv(k, v)
}
}
co, stop := initRecording(t)
defer stop()
cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: co})
require.NoError(t, err)
testGetTokenSuccess(t, cred)

t.Run("ClientID", func(t *testing.T) {
if recording.GetRecordMode() != recording.PlaybackMode && liveManagedIdentity.clientID == "" {
t.Skip("set IDENTITY_VM_USER_ASSIGNED_MI_CLIENT_ID to run this test")
}
t.Setenv(azureClientID, liveManagedIdentity.clientID)
co, stop := initRecording(t)
defer stop()
cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: co})
require.NoError(t, err)
testGetTokenSuccess(t, cred)
})
}

// delayPolicy adds a delay to pipeline requests. Used to test timeout behavior.
type delayPolicy struct {
delay time.Duration
Expand All @@ -256,45 +287,52 @@ func (p *delayPolicy) Do(req *policy.Request) (resp *http.Response, err error) {
return req.Next()
}

func TestDefaultAzureCredential_timeoutWrapper(t *testing.T) {
timeout := 100 * time.Millisecond
dp := delayPolicy{2 * timeout}
mic, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{
func TestDefaultAzureCredential_IMDSTimeout(t *testing.T) {
// unsetting environment variables to skip EnvironmentCredential and other managed identity sources
for _, k := range []string{azureTenantID, identityEndpoint, msiEndpoint} {
if v, set := os.LookupEnv(k); set {
require.NoError(t, os.Unsetenv(k))
defer os.Setenv(k, v)
}
}

// AzureCLICredential returning an error ensures we see the ManagedIdentityCredential timeout error
datp := defaultAzTokenProvider
defer func() { defaultAzTokenProvider = datp }()
defaultAzTokenProvider = mockAzTokenProviderFailure

// shorten the timeout to speed up this test
ipt := imdsProbeTimeout
defer func() { imdsProbeTimeout = ipt }()
imdsProbeTimeout = 100 * time.Millisecond

dp := delayPolicy{2 * imdsProbeTimeout}
chain, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
ClientOptions: policy.ClientOptions{
PerCallPolicies: []policy.Policy{&dp},
Retry: policy.RetryOptions{MaxRetries: -1},
Transport: &mockSTS{},
},
})
if err != nil {
t.Fatal(err)
}
wrapper := timeoutWrapper{mic, timeout}
chain, err := NewChainedTokenCredential([]azcore.TokenCredential{&wrapper}, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
for i := 0; i < 2; i++ {
// expecting credentialUnavailableError because delay exceeds the wrapper's timeout
// expecting an error because managed identity times out and AzureCLICredential returns an error
_, err = chain.GetToken(context.Background(), testTRO)
if _, ok := err.(credentialUnavailable); !ok {
t.Fatalf("expected credentialUnavailable, got %T: %v", err, err)
}
require.ErrorContains(t, err, credNameManagedIdentity+": managed identity timed out")
}

// remove the delay so the credential can authenticate
// remove the delay so ManagedIdentityCredential can get a token from the fake STS
dp.delay = 0
tk, err := chain.GetToken(context.Background(), testTRO)
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf(`got unexpected token "%s"`, tk.Token)
}
// now there should be no special timeout (using a different scope bypasses the cache, forcing a token request)
dp.delay = 3 * timeout
tk, err = chain.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"not-" + liveTestScope}})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
require.Equal(t, tokenValue, tk.Token)

// now there should be no timeout on token requests
dp.delay = 2 * imdsProbeTimeout
tk, err = chain.GetToken(context.Background(), policy.TokenRequestOptions{
// using a different scope forces a token request by bypassing the cache
Scopes: []string{"not-" + testTRO.Scopes[0]},
})
require.NoError(t, err)
require.Equal(t, tokenValue, tk.Token)
}
34 changes: 28 additions & 6 deletions sdk/azidentity/managed_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ const (
serviceFabricAPIVersion = "2019-07-01-preview"
)

var imdsProbeTimeout = time.Second

type msiType int

const (
Expand All @@ -55,13 +57,12 @@ const (
msiTypeServiceFabric
)

// managedIdentityClient provides the base for authenticating in managed identity environments
// This type includes an runtime.Pipeline and TokenCredentialOptions.
type managedIdentityClient struct {
azClient *azcore.Client
msiType msiType
endpoint string
id ManagedIDKind
azClient *azcore.Client
endpoint string
id ManagedIDKind
msiType msiType
probeIMDS bool
}

type wrappedNumber json.Number
Expand Down Expand Up @@ -147,6 +148,7 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
c.msiType = msiTypeCloudShell
}
} else {
c.probeIMDS = options.dac
setIMDSRetryOptionDefaults(&cp.Retry)
}

Expand Down Expand Up @@ -180,6 +182,26 @@ func (c *managedIdentityClient) provideToken(ctx context.Context, params confide

// authenticate acquires an access token
func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (azcore.AccessToken, error) {
// no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client,
// and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block
if c.probeIMDS {
cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
defer cancel()
req, err := runtime.NewRequest(cx, http.MethodGet, c.endpoint)
if err == nil {
_, err = c.azClient.Pipeline().Do(req)
}
if err != nil {
msg := err.Error()
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information"
}
return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
}
// send normal token requests from now on because something responded
c.probeIMDS = false
}

msg, err := c.createAuthRequest(ctx, id, scopes)
if err != nil {
return azcore.AccessToken{}, err
Expand Down
7 changes: 7 additions & 0 deletions sdk/azidentity/managed_identity_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ type ManagedIdentityCredentialOptions struct {
// instead of the hosting environment's default. The value may be the identity's client ID or resource ID, but note that
// some platforms don't accept resource IDs.
ID ManagedIDKind

// dac indicates whether the credential is part of DefaultAzureCredential. When true, and the environment doesn't have
// configuration for a specific managed identity API, the credential tries to determine whether IMDS is available before
// sending its first token request. It does this by sending a malformed request with a short timeout. Any response to that
// request is taken to mean IMDS is available, in which case the credential will send ordinary token requests thereafter
// with no special timeout. The purpose of this behavior is to prevent a very long timeout when IMDS isn't available.
dac bool
}

// ManagedIdentityCredential authenticates an Azure managed identity in any hosting environment supporting managed identities.
Expand Down

0 comments on commit 8f702b9

Please sign in to comment.