Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Release/2.0] Fix pooled connection re-use on access token expiry #639

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ internal static partial class ADP
internal static Task<bool> FalseTask => _falseTask ?? (_falseTask = Task.FromResult(false));

internal const CompareOptions DefaultCompareOptions = CompareOptions.IgnoreKanaType | CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase;

internal const int DefaultConnectionTimeout = DbConnectionStringDefaults.ConnectTimeout;
internal const int InfiniteConnectionTimeout = 0; // infinite connection timeout identifier in seconds
internal const int MaxBufferAccessTokenExpiry = 600; // max duration for buffer in seconds

static private void TraceException(string trace, Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ virtual protected bool ReadyToPrepareTransaction
}
}

internal virtual bool IsAccessTokenExpired => false;

abstract protected void Activate(Transaction transaction);

internal void ActivateConnection(Transaction transaction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
_waitHandles.CreationSemaphore.Release(1);
}
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if(null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
{
SqlConnectionString opt = (SqlConnectionString)options;
SqlConnectionPoolKey key = (SqlConnectionPoolKey)poolKey;
SqlInternalConnection result = null;
SessionData recoverySessionData = null;

SqlConnection sqlOwningConnection = (SqlConnection)owningConnection;
Expand Down Expand Up @@ -131,8 +130,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return result;
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,8 @@ static internal Exception InvalidArgumentValue(string methodName)
internal const int DecimalMaxPrecision28 = 28; // there are some cases in Odbc where we need that ...
internal const int DefaultCommandTimeout = 30;
internal const int DefaultConnectionTimeout = DbConnectionStringDefaults.ConnectTimeout;
internal const int InfiniteConnectionTimeout = 0; // infinite connection timeout identifier in seconds
internal const int MaxBufferAccessTokenExpiry = 600; // max duration for buffer in seconds
internal const float FailoverTimeoutStep = 0.08F; // fraction of timeout to use for fast failover connections
internal const float FailoverTimeoutStepForTnir = 0.125F; // Fraction of timeout to use in case of Transparent Network IP resolution.
internal const int MinimumTimeoutForTnirMs = 500; // The first login attempt in Transparent network IP Resolution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ public ConnectionState State
}
}

internal virtual bool IsAccessTokenExpired => false;

abstract protected void Activate(SysTx.Transaction transaction);

internal void ActivateConnection(SysTx.Transaction transaction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
{
Marshal.ThrowExceptionForHR(releaseSemaphoreResult); // will only throw if (hresult < 0)
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if (null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// Connection Resiliency
private bool _sessionRecoveryRequested;
internal bool _sessionRecoveryAcknowledged;
internal SessionData _currentSessionData; // internal for use from TdsParser only, otehr should use CurrentSessionData property that will fix database and language
internal SessionData _currentSessionData; // internal for use from TdsParser only, other should use CurrentSessionData property that will fix database and language
private SessionData _recoverySessionData;

// Federated Authentication
Expand All @@ -131,13 +131,14 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
internal bool _federatedAuthenticationInfoRequested; // Keep this distinct from _federatedAuthenticationRequested, since some fedauth library types may not need more info
internal bool _federatedAuthenticationInfoReceived;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;

// Certificate auth calbacks.
//
ServerCertificateValidationCallback _serverCallback;
ClientCertificateRetrievalCallback _clientCallback;
SqlClientOriginalNetworkAddressInfo _originalNetworkAddressInfo;
Expand All @@ -146,6 +147,18 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa

private bool _serverSupportsDNSCaching = false;

/// <summary>
/// Returns buffer time allowed before access token expiry to continue using the access token.
/// </summary>
private int accessTokenExpirationBufferTime
{
get
{
return (ConnectionOptions.ConnectTimeout == ADP.InfiniteConnectionTimeout || ConnectionOptions.ConnectTimeout >= ADP.MaxBufferAccessTokenExpiry)
? ADP.MaxBufferAccessTokenExpiry : ConnectionOptions.ConnectTimeout;
}
}

/// <summary>
/// Get or set if SQLDNSCaching FeatureExtAck is supported by the server.
/// </summary>
Expand Down Expand Up @@ -808,6 +821,10 @@ protected override bool UnbindOnTransactionCompletion
}
}

/// <summary>
/// Validates if federated authentication is used, Access Token used by this connection is active for the value of 'accessTokenExpirationBufferTime'.
/// </summary>
internal override bool IsAccessTokenExpired => _federatedAuthenticationInfoRequested && DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime) < DateTime.UtcNow.AddSeconds(accessTokenExpirationBufferTime);

////////////////////////////////////////////////////////////////////////////////////////
// GENERAL METHODS
Expand Down Expand Up @@ -1321,10 +1338,10 @@ internal void ExecuteTransactionYukon(
ThreadHasParserLockForClose = false;
_parserLock.Release();
releaseConnectionLock = false;
}, 0);
}, ADP.InfiniteConnectionTimeout);
if (reconnectTask != null)
{
AsyncHelper.WaitForCompletion(reconnectTask, 0); // there is no specific timeout for BeginTransaction, uses ConnectTimeout
AsyncHelper.WaitForCompletion(reconnectTask, ADP.InfiniteConnectionTimeout); // there is no specific timeout for BeginTransaction, uses ConnectTimeout
internalTransaction.ConnectionHasBeenRestored = true;
return;
}
Expand Down Expand Up @@ -2538,9 +2555,6 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
// We want to refresh the token, if taking the lock on the authentication context is successful.
bool attemptRefreshTokenLocked = false;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken fedAuthToken = null;

if (_dbConnectionPool != null)
{
Debug.Assert(_dbConnectionPool.AuthenticationContexts != null);
Expand Down Expand Up @@ -2575,7 +2589,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
}
else if (_forceExpiryLocked)
{
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);
}
#endif

Expand All @@ -2589,11 +2603,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)

// Call the function which tries to acquire a lock over the authentication context before trying to update.
// If the lock could not be obtained, it will return false, without attempting to fetch a new token.
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);

// If TryGetFedAuthTokenLocked returns true, it means lock was obtained and fedAuthToken should not be null.
// If TryGetFedAuthTokenLocked returns true, it means lock was obtained and _fedAuthToken should not be null.
// If there was an exception in retrieving the new token, TryGetFedAuthTokenLocked should have thrown, so we won't be here.
Debug.Assert(!attemptRefreshTokenLocked || fedAuthToken != null, "Either Lock should not have been obtained or fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _fedAuthToken != null, "Either Lock should not have been obtained or _fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _newDbConnectionPoolAuthenticationContext != null, "Either Lock should not have been obtained or _newDbConnectionPoolAuthenticationContext should not be null.");

// Indicate in Bid Trace that we are successful with the update.
Expand All @@ -2610,8 +2624,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
if (dbConnectionPoolAuthenticationContext == null || attemptRefreshTokenUnLocked)
{
// Get the Federated Authentication Token.
fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
_fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(_fedAuthToken != null, "_fedAuthToken should not be null.");

if (_dbConnectionPool != null)
{
Expand All @@ -2622,18 +2636,19 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
else if (!attemptRefreshTokenLocked)
{
Debug.Assert(dbConnectionPoolAuthenticationContext != null, "dbConnectionPoolAuthenticationContext should not be null.");
Debug.Assert(fedAuthToken == null, "fedAuthToken should be null in this case.");
Debug.Assert(_fedAuthToken == null, "_fedAuthToken should be null in this case.");
Debug.Assert(_newDbConnectionPoolAuthenticationContext == null, "_newDbConnectionPoolAuthenticationContext should be null.");

fedAuthToken = new SqlFedAuthToken();
_fedAuthToken = new SqlFedAuthToken();

// If the code flow is here, then we are re-using the context from the cache for this connection attempt and not
// generating a new access token on this thread.
fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.expirationFileTime = dbConnectionPoolAuthenticationContext.ExpirationTime.ToFileTime();
}

Debug.Assert(fedAuthToken != null && fedAuthToken.accessToken != null, "fedAuthToken and fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(fedAuthToken);
Debug.Assert(_fedAuthToken != null && _fedAuthToken.accessToken != null, "_fedAuthToken and _fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(_fedAuthToken);
}

/// <summary>
Expand Down Expand Up @@ -2873,7 +2888,8 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
{
if (_routingInfo != null)
{
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId) {
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId)
{
return;
}
}
Expand Down Expand Up @@ -3101,16 +3117,18 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream);
}

if (1 == data[0]) {
if (1 == data[0])
{
IsSQLDNSCachingSupported = true;
_cleanSQLDNSCaching = false;

if (_routingInfo != null)
{
IsDNSCachingBeforeRedirectSupported = true;
}
}
else {
else
{
// we receive the IsSupported whose value is 0
IsSQLDNSCachingSupported = false;
_cleanSQLDNSCaching = true;
Expand Down