diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 5eed2f7177..222b979ca8 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -227,6 +227,7 @@ private SqlConnection(SqlConnection connection) } _accessToken = connection._accessToken; + _accessTokenCallback = connection._accessTokenCallback; CacheConnectionStringProperties(); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index 8fca4ce207..652cff4f30 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -429,6 +429,7 @@ private SqlConnection(SqlConnection connection) _credential = new SqlCredential(connection._credential.UserId, password); } _accessToken = connection._accessToken; + _accessTokenCallback = connection._accessTokenCallback; _serverCertificateValidationCallback = connection._serverCertificateValidationCallback; _clientCertificateRetrievalCallback = connection._clientCertificateRetrievalCallback; _originalNetworkAddressInfo = connection._originalNetworkAddressInfo; diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/CloneTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/CloneTests.cs index f5deb6c62c..1c8efc4456 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/CloneTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/CloneTests.cs @@ -4,6 +4,7 @@ using System; using System.Data; +using System.Threading.Tasks; using Xunit; namespace Microsoft.Data.SqlClient.Tests @@ -18,10 +19,19 @@ public void CloneSqlConnection() builder.ConnectTimeout = 1; builder.InitialCatalog = "northwinddb"; SqlConnection connection = new SqlConnection(builder.ConnectionString); + connection.AccessToken = Guid.NewGuid().ToString(); SqlConnection clonedConnection = (connection as ICloneable).Clone() as SqlConnection; Assert.Equal(connection.ConnectionString, clonedConnection.ConnectionString); Assert.Equal(connection.ConnectionTimeout, clonedConnection.ConnectionTimeout); + Assert.Equal(connection.AccessToken, clonedConnection.AccessToken); + Assert.NotEqual(connection, clonedConnection); + + connection = new SqlConnection(builder.ConnectionString); + connection.AccessTokenCallback = (ctx, token) => + Task.FromResult(new SqlAuthenticationToken(Guid.NewGuid().ToString(), DateTimeOffset.MaxValue)); + clonedConnection = (connection as ICloneable).Clone() as SqlConnection; + Assert.Equal(connection.AccessTokenCallback, clonedConnection.AccessTokenCallback); Assert.NotEqual(connection, clonedConnection); }