Skip to content

Commit

Permalink
Fix for SqlConnection failure when having multiple concurrent users (…
Browse files Browse the repository at this point in the history
…#25620)
  • Loading branch information
Gene Lee committed Jan 10, 2018
1 parent b8b87d5 commit af9a6e8
Showing 1 changed file with 68 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,20 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
}

connectTask = ParallelConnectAsync(serverAddresses, port);

if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts)))
{
ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty);
return;
}

_socket = connectTask.Result;
}
else
{
connectTask = ConnectAsync(serverName, port);
}

if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts)))
{
ReportTcpSNIError(0, SNICommon.ConnOpenFailedError, string.Empty);
return;
_socket = Connect(serverName, port, ts);
}

_socket = connectTask.Result;

if (_socket == null || !_socket.Connected)
{
if (_socket != null)
Expand Down Expand Up @@ -182,31 +183,72 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_status = TdsEnums.SNI_SUCCESS;
}

private static async Task<Socket> ConnectAsync(string serverName, int port)
private static Socket Connect(string serverName, int port, TimeSpan timeout)
{
IPAddress[] addresses = await Dns.GetHostAddressesAsync(serverName).ConfigureAwait(false);
IPAddress targetAddrV4 = Array.Find(addresses, addr => (addr.AddressFamily == AddressFamily.InterNetwork));
IPAddress targetAddrV6 = Array.Find(addresses, addr => (addr.AddressFamily == AddressFamily.InterNetworkV6));
if (targetAddrV4 != null && targetAddrV6 != null)
IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
IPAddress serverIPv4 = null;
IPAddress serverIPv6 = null;
foreach (IPAddress ipAdress in ipAddresses)
{
return await ParallelConnectAsync(new IPAddress[] { targetAddrV4, targetAddrV6 }, port).ConfigureAwait(false);
if (ipAdress.AddressFamily == AddressFamily.InterNetwork)
{
serverIPv4 = ipAdress;
}
else if (ipAdress.AddressFamily == AddressFamily.InterNetworkV6)
{
serverIPv6 = ipAdress;
}
}
else
{
IPAddress targetAddr = (targetAddrV4 != null) ? targetAddrV4 : targetAddrV6;
var socket = new Socket(targetAddr.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
ipAddresses = new IPAddress[] { serverIPv4, serverIPv6 };
Socket[] sockets = new Socket[2];

try
CancellationTokenSource cts = new CancellationTokenSource();
cts.CancelAfter(timeout);
void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
{
await socket.ConnectAsync(targetAddr, port).ConfigureAwait(false);
try
{
if (sockets[i] != null && !sockets[i].Connected)
{
sockets[i].Dispose();
sockets[i] = null;
}
}
catch { }
}
catch
}
cts.Token.Register(Cancel);

Socket availableSocket = null;
for (int i = 0; i < sockets.Length; ++i)
{
try
{
socket.Dispose();
throw;
if (ipAddresses[i] != null)
{
sockets[i] = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp);
sockets[i].Connect(ipAddresses[i], port);
if (sockets[i] != null) // sockets[i] can be null if cancel callback is executed during connect()
{
if (sockets[i].Connected)
{
availableSocket = sockets[i];
break;
}
else
{
sockets[i].Dispose();
sockets[i] = null;
}
}
}
}
return socket;
catch { }
}

return availableSocket;
}

private static Task<Socket> ParallelConnectAsync(IPAddress[] serverAddresses, int port)
Expand Down Expand Up @@ -320,7 +362,7 @@ public override uint EnableSsl(uint options)

try
{
_sslStream.AuthenticateAsClientAsync(_targetServer).GetAwaiter().GetResult();
_sslStream.AuthenticateAsClient(_targetServer);
_sslOverTdsStream.FinishHandshake();
}
catch (AuthenticationException aue)
Expand Down

0 comments on commit af9a6e8

Please sign in to comment.