diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs index 214d57ccb717..71ce594bd91e 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNITcpHandle.cs @@ -137,19 +137,20 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba } connectTask = ParallelConnectAsync(serverAddresses, port); + + if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts))) + { + ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty); + return; + } + + _socket = connectTask.Result; } else { - connectTask = ConnectAsync(serverName, port); - } - - if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts))) - { - ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty); - return; + _socket = Connect(serverName, port, ts); } - - _socket = connectTask.Result; + if (_socket == null || !_socket.Connected) { if (_socket != null) @@ -182,31 +183,72 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba _status = TdsEnums.SNI_SUCCESS; } - private static async Task ConnectAsync(string serverName, int port) + private static Socket Connect(string serverName, int port, TimeSpan timeout) { - IPAddress[] addresses = await Dns.GetHostAddressesAsync(serverName).ConfigureAwait(false); - IPAddress targetAddrV4 = Array.Find(addresses, addr => (addr.AddressFamily == AddressFamily.InterNetwork)); - IPAddress targetAddrV6 = Array.Find(addresses, addr => (addr.AddressFamily == AddressFamily.InterNetworkV6)); - if (targetAddrV4 != null && targetAddrV6 != null) + IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); + IPAddress serverIPv4 = null; + IPAddress serverIPv6 = null; + foreach (IPAddress ipAdress in ipAddresses) { - return await ParallelConnectAsync(new IPAddress[] { targetAddrV4, targetAddrV6 }, port).ConfigureAwait(false); + if (ipAdress.AddressFamily == AddressFamily.InterNetwork) + { + serverIPv4 = ipAdress; + } + else if (ipAdress.AddressFamily == AddressFamily.InterNetworkV6) + { + serverIPv6 = ipAdress; + } } - else - { - IPAddress targetAddr = (targetAddrV4 != null) ? targetAddrV4 : targetAddrV6; - var socket = new Socket(targetAddr.AddressFamily, SocketType.Stream, ProtocolType.Tcp); + ipAddresses = new IPAddress[] { serverIPv4, serverIPv6 }; + Socket[] sockets = new Socket[2]; - try + CancellationTokenSource cts = new CancellationTokenSource(); + cts.CancelAfter(timeout); + void Cancel() + { + for (int i = 0; i < sockets.Length; ++i) { - await socket.ConnectAsync(targetAddr, port).ConfigureAwait(false); + try + { + if (sockets[i] != null && !sockets[i].Connected) + { + sockets[i].Dispose(); + sockets[i] = null; + } + } + catch { } } - catch + } + cts.Token.Register(Cancel); + + Socket availableSocket = null; + for (int i = 0; i < sockets.Length; ++i) + { + try { - socket.Dispose(); - throw; + if (ipAddresses[i] != null) + { + sockets[i] = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp); + sockets[i].Connect(ipAddresses[i], port); + if (sockets[i] != null) // sockets[i] can be null if cancel callback is executed during connect() + { + if (sockets[i].Connected) + { + availableSocket = sockets[i]; + break; + } + else + { + sockets[i].Dispose(); + sockets[i] = null; + } + } + } } - return socket; + catch { } } + + return availableSocket; } private static Task ParallelConnectAsync(IPAddress[] serverAddresses, int port) @@ -320,7 +362,7 @@ public override uint EnableSsl(uint options) try { - _sslStream.AuthenticateAsClientAsync(_targetServer).GetAwaiter().GetResult(); + _sslStream.AuthenticateAsClient(_targetServer); _sslOverTdsStream.FinishHandshake(); } catch (AuthenticationException aue) diff --git a/src/System.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs b/src/System.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs index a94d7530832f..c1ffe56fff24 100644 --- a/src/System.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs +++ b/src/System.Data.SqlClient/tests/FunctionalTests/SqlConnectionTest.cs @@ -2,16 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. - +using System.Collections.Generic; using System.Data.Common; +using System.Diagnostics; using System.Reflection; +using System.Threading; using Xunit; namespace System.Data.SqlClient.Tests { public class SqlConnectionBasicTests { - [Fact] public void ConnectionTest() { @@ -50,5 +51,97 @@ public void SqlConnectionDbProviderFactoryTest() Assert.Same(typeof(SqlClientFactory), factory.GetType()); Assert.Same(SqlClientFactory.Instance, factory); } + + [Fact] + public void ConnectionTimeoutTestWithThread() + { + int timeoutSec = 5; + string connStrNotAvailable = $"Server=tcp:fakeServer,1433;uid=fakeuser;pwd=fakepwd;Connection Timeout={timeoutSec}"; + + List list = new List(); + for (int i = 0; i < 10; ++i) + { + list.Add(new ConnectionWorker(connStrNotAvailable)); + } + + ConnectionWorker.Start(); + ConnectionWorker.Stop(); + + double theMax = 0; + foreach (ConnectionWorker w in list) + { + if (theMax < w.MaxTimeElapsed) + { + theMax = w.MaxTimeElapsed; + } + } + + int threshold = (timeoutSec + 1) * 1000; + Assert.True(theMax < threshold); + } + + public class ConnectionWorker + { + private static ManualResetEventSlim startEvent = new ManualResetEventSlim(false); + private static List workerList = new List(); + private ManualResetEventSlim doneEvent = new ManualResetEventSlim(false); + private double maxTimeElapsed; + private Thread thread; + private string connectionString; + + public ConnectionWorker(string connectionString) + { + workerList.Add(this); + this.connectionString = connectionString; + thread = new Thread(new ThreadStart(SqlConnectionOpen)); + thread.Start(); + } + + public double MaxTimeElapsed + { + get + { + return maxTimeElapsed; + } + } + + public static void Start() + { + startEvent.Set(); + } + + public static void Stop() + { + foreach (ConnectionWorker w in workerList) + { + w.doneEvent.Wait(); + } + } + + public void SqlConnectionOpen() + { + startEvent.Wait(); + + Stopwatch sw = new Stopwatch(); + using (SqlConnection con = new SqlConnection(connectionString)) + { + sw.Start(); + try + { + con.Open(); + } + catch { } + sw.Stop(); + } + + double elapsed = sw.Elapsed.TotalMilliseconds; + if (maxTimeElapsed < elapsed) + { + maxTimeElapsed = elapsed; + } + + doneEvent.Set(); + } + } } }