From 63a897a5a2361d6c5a3258b48341fff0f409dbae Mon Sep 17 00:00:00 2001 From: Davoud Eshtehari Date: Thu, 13 Apr 2023 11:43:27 -0700 Subject: [PATCH] Fix | Throttling of token requests by calling AcquireTokenSilent (#1925) * Address throttling of token requests by calling AcquireTokenSilent in Integrated/Password flows when the account is already cached. Addresses issue #1915 Co-authored-by: Lawrence Cheung <31262254+lcheunglci@users.noreply.github.com> Co-authored-by: DavoudEshtehari <61173489+DavoudEshtehari@users.noreply.github.com> # Conflicts: # src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs --- .../ActiveDirectoryAuthenticationProvider.cs | 197 +++++++++++------- tools/props/Versions.props | 2 +- 2 files changed, 127 insertions(+), 72 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index a8fdf219d3..64caf32060 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -4,7 +4,10 @@ using System; using System.Collections.Concurrent; -using System.Security; +using System.Linq; +using System.Runtime.Caching; +using System.Security.Cryptography; +using System.Text; using System.Threading; using System.Threading.Tasks; using Azure.Core; @@ -24,6 +27,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro /// private static ConcurrentDictionary s_pcaMap = new ConcurrentDictionary(); + private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider)); + private static readonly int s_accountPwCacheTtlInHours = 2; private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; private static readonly string s_defaultScopeSuffix = "/.default"; private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; @@ -171,7 +176,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } - AuthenticationResult result; + AuthenticationResult result = null; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) { AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); @@ -207,86 +212,82 @@ public override async Task AcquireTokenAsync(SqlAuthenti if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - if (!string.IsNullOrEmpty(parameters.UserId)) - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .WithUsername(parameters.UserId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - } - else - { - result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) - { - SecureString password = new SecureString(); - foreach (char c in parameters.Password) - password.AppendChar(c); - password.MakeReadOnly(); - - result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token) - .ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); - } - else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || - parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) - { - // Fetch available accounts from 'app' instance - System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); - IAccount account = default; - if (accounts.MoveNext()) + if (null == result) { if (!string.IsNullOrEmpty(parameters.UserId)) { - do - { - IAccount currentVal = accounts.Current; - if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) - { - account = currentVal; - break; - } - } - while (accounts.MoveNext()); + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .WithUsername(parameters.UserId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); } else { - account = accounts.Current; + result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); } + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); + } + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) + { + string pwCacheKey = GetAccountPwCacheKey(parameters); + object previousPw = s_accountPwCache.Get(pwCacheKey); + byte[] currPwHash = GetHash(parameters.Password); + + if (null != previousPw && + previousPw is byte[] previousPwBytes && + // Only get the cached token if the current password hash matches the previously used password hash + currPwHash.SequenceEqual(previousPwBytes)) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); } - if (null != account) + if (null == result) { - try + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, parameters.Password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); + + // We cache the password hash to ensure future connection requests include a validated password + // when we check for a cached MSAL account. Otherwise, a connection request with the same username + // against the same tenant could succeed with an invalid password when we re-use the cached token. + if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours))) { - // If 'account' is available in 'app', we use the same to acquire token silently. - // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent - result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); - } - catch (MsalUiRequiredException) - { - // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, - // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), - // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + s_accountPwCache.Remove(pwCacheKey); + s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)); } + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); } - else + } + else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || + parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) + { + try + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + catch (MsalUiRequiredException) + { + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + if (null == result) { // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } } @@ -299,8 +300,49 @@ public override async Task AcquireTokenAsync(SqlAuthenti return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); } - private async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts) + private static async Task TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters, + string[] scopes, CancellationTokenSource cts) + { + AuthenticationResult result = null; + + // Fetch available accounts from 'app' instance + System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + + IAccount account = default; + if (accounts.MoveNext()) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + do + { + IAccount currentVal = accounts.Current; + if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) + { + account = currentVal; + break; + } + } + while (accounts.MoveNext()); + } + else + { + account = accounts.Current; + } + } + + if (null != account) + { + // If 'account' is available in 'app', we use the same to acquire token silently. + // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent + result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + return result; + } + + private static async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, + SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func deviceCodeFlowCallback) { try { @@ -319,11 +361,11 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( */ ctsInteractive.CancelAfter(180000); #endif - if (_customWebUI != null) + if (customWebUI != null) { return await app.AcquireTokenInteractive(scopes) .WithCorrelationId(connectionId) - .WithCustomWebUi(_customWebUI) + .WithCustomWebUi(customWebUI) .WithLoginHint(userId) .ExecuteAsync(ctsInteractive.Token) .ConfigureAwait(false); @@ -357,7 +399,7 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( else { AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, - deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)) + deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)) .WithCorrelationId(connectionId) .ExecuteAsync(cancellationToken: cts.Token) .ConfigureAwait(false); @@ -410,6 +452,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p return clientApplicationInstance; } + private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters) + { + return parameters.Authority + "+" + parameters.UserId; + } + + private static byte[] GetHash(string input) + { + byte[] unhashedBytes = Encoding.Unicode.GetBytes(input); + SHA256 sha256 = SHA256.Create(); + byte[] hashedBytes = sha256.ComputeHash(unhashedBytes); + return hashedBytes; + } + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) { IPublicClientApplication publicClientApplication; diff --git a/tools/props/Versions.props b/tools/props/Versions.props index 611ed6aa60..54477cb9ec 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -18,7 +18,7 @@ 1.3.0 - 4.22.0 + 4.47.2 6.8.0 6.8.0 4.5.1