Skip to content

Commit

Permalink
Parallelize SSRP requests when MSF is specified (#1578) (#1708)
Browse files Browse the repository at this point in the history
DavoudEshtehari authored Aug 30, 2022
1 parent 5cca4a7 commit 99bc353
Showing 5 changed files with 319 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -108,6 +108,7 @@ internal class SNICommon
internal const int ConnTimeoutError = 11;
internal const int ConnNotUsableError = 19;
internal const int InvalidConnStringError = 25;
internal const int ErrorLocatingServerInstance = 26;
internal const int HandshakeFailureError = 31;
internal const int InternalExceptionError = 35;
internal const int ConnOpenFailedError = 40;
Original file line number Diff line number Diff line change
@@ -141,7 +141,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
/// <param name="isIntegratedSecurity"></param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
@@ -263,7 +263,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
@@ -285,12 +285,12 @@ private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire
try
{
port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName);
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference);
}
catch (SocketException se)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InvalidConnStringError, se);
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.ErrorLocatingServerInstance, se);
return null;
}
}
Original file line number Diff line number Diff line change
@@ -146,9 +146,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, bool parallel
bool reportError = true;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port);
// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with cached IPs based on IPAddressPreference.
// The exceptions will be throw to upper level and be handled as before.
// We will always first try to connect with serverName as before and let DNS resolve the serverName.
// If DNS resolution fails, we will try with IPs in the DNS cache if they exist. We try with cached IPs based on IPAddressPreference.
// Exceptions will be thrown to the caller and be handled as before.
try
{
if (parallel)
@@ -280,7 +280,12 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
Task<Socket> connectTask;

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
serverAddrTask.Wait(ts);
bool complete = serverAddrTask.Wait(ts);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] serverAddresses = serverAddrTask.Result;

if (serverAddresses.Length > MaxParallelIpAddresses)
@@ -324,7 +329,6 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i

availableSocket = connectTask.Result;
return availableSocket;

}

// Connect to server with hostName and port.
@@ -334,7 +338,14 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));

IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
bool complete = serverAddrTask.Wait(timeout);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] ipAddresses = serverAddrTask.Result;

string IPv4String = null;
string IPv6String = null;
Original file line number Diff line number Diff line change
@@ -3,10 +3,13 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
@@ -21,8 +24,11 @@ internal class SSRP
/// </summary>
/// <param name="browserHostName">SQL Sever Browser hostname</param>
/// <param name="instanceName">instance name to find port number</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>port number for given instance name</returns>
internal static int GetPortByInstanceName(string browserHostName, string instanceName)
internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");
@@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc
byte[] responsePacket = null;
try
{
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest);
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference);
}
catch (SocketException se)
{
@@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName)
/// </summary>
/// <param name="browserHostName">SQL Sever Browser hostname</param>
/// <param name="instanceName">instance name to lookup DAC port</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>DAC port for given instance name</returns>
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName)
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");

byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName);
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest);
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference);

const byte SvrResp = 0x05;
const byte ProtocolVersion = 0x01;
@@ -131,43 +140,198 @@ private static byte[] CreateDacPortInfoRequest(string instanceName)
return requestPacket;
}

private class SsrpResult
{
public byte[] ResponsePacket;
public Exception Error;
}

/// <summary>
/// Sends request to server, and receives response from server by UDP.
/// </summary>
/// <param name="browserHostname">UDP server hostname</param>
/// <param name="port">UDP server port</param>
/// <param name="requestPacket">request packet</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>response packet from UDP server</returns>
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket)
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
using (TrySNIEventScope.Create(nameof(SSRP)))
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostname), "browserhostname should not be null, empty, or whitespace");
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");

const int sendTimeOutMs = 1000;
const int receiveTimeOutMs = 1000;
bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);

IPAddress address = null;
bool isIpAddress = IPAddress.TryParse(browserHostname, out address);
TimeSpan ts = default;
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
// The infinite Timeout is a function of ConnectionString Timeout=0
if (long.MaxValue != timerExpire)
{
ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

byte[] responsePacket = null;
using (UdpClient client = new UdpClient(!isIpAddress ? AddressFamily.InterNetwork : address.AddressFamily))
IPAddress[] ipAddresses = null;
if (!isIpAddress)
{
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
bool taskComplete;
try
{
taskComplete = serverAddrTask.Wait(ts);
}
catch (AggregateException ae)
{
throw ae.InnerException;
}

// If DNS took too long, need to return instead of blocking
if (!taskComplete)
return null;

ipAddresses = serverAddrTask.Result;
}

Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

switch (ipPreference)
{
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, browserHostname, port);
case SqlConnectionIPAddressPreference.IPv4First:
{
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
return response4.ResponsePacket;

SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
return response6.ResponsePacket;

// No responses so throw first error
if (response4 != null && response4.Error != null)
throw response4.Error;
else if (response6 != null && response6.Error != null)
throw response6.Error;

break;
}
case SqlConnectionIPAddressPreference.IPv6First:
{
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
return response6.ResponsePacket;

SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
return response4.ResponsePacket;

// No responses so throw first error
if (response6 != null && response6.Error != null)
throw response6.Error;
else if (response4 != null && response4.Error != null)
throw response4.Error;

break;
}
default:
{
SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel);
if (response != null && response.ResponsePacket != null)
return response.ResponsePacket;
else if (response != null && response.Error != null)
throw response.Error;

break;
}
}

return null;
}
}

/// <summary>
/// Sends request to server, and receives response from server by UDP.
/// </summary>
/// <param name="ipAddresses">IP Addresses</param>
/// <param name="port">UDP server port</param>
/// <param name="requestPacket">request packet</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <returns>response packet from UDP server</returns>
private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel)
{
if (ipAddresses.Length == 0)
return null;

if (allIPsInParallel) // Used for MultiSubnetFailover
{
List<Task<SsrpResult>> tasks = new(ipAddresses.Length);
CancellationTokenSource cts = new CancellationTokenSource();
for (int i = 0; i < ipAddresses.Length; i++)
{
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
}

List<Task<SsrpResult>> completedTasks = new();
while (tasks.Count > 0)
{
int first = Task.WaitAny(tasks.ToArray());
if (tasks[first].Result.ResponsePacket != null)
{
cts.Cancel();
return tasks[first].Result;
}
else
{
completedTasks.Add(tasks[first]);
tasks.Remove(tasks[first]);
}
}

Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0");

// All tasks failed. Return the error from the first failure.
return completedTasks[0].Result;
}
else
{
// If not parallel, use the first IP address provided
IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port);
return SendUDPRequest(endPoint, requestPacket);
}
}

private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket)
{
const int sendTimeOutMs = 1000;
const int receiveTimeOutMs = 1000;

SsrpResult result = new();

try
{
using (UdpClient client = new UdpClient(endPoint.AddressFamily))
{
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint);
Task<UdpReceiveResult> receiveTask = null;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info.");
if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client.");
responsePacket = receiveTask.Result.Buffer;
result.ResponsePacket = receiveTask.Result.Buffer;
}
}

return responsePacket;
}
catch (Exception e)
{
result.Error = e;
}

return result;
}
}
}
Original file line number Diff line number Diff line change
@@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
using Xunit;

@@ -13,7 +15,7 @@ public static class InstanceNameTest
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectToSQLWithInstanceNameTest()
{
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString);
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

bool proceed = true;
string dataSourceStr = builder.DataSource.Replace("tcp:", "");
@@ -26,24 +28,116 @@ public static void ConnectToSQLWithInstanceNameTest()

if (proceed)
{
using (SqlConnection connection = new SqlConnection(builder.ConnectionString))
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
connection.Close();
}
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer))]
[InlineData(true, SqlConnectionIPAddressPreference.IPv4First)]
[InlineData(true, SqlConnectionIPAddressPreference.IPv6First)]
[InlineData(true, SqlConnectionIPAddressPreference.UsePlatformDefault)]
[InlineData(false, SqlConnectionIPAddressPreference.IPv4First)]
[InlineData(false, SqlConnectionIPAddressPreference.IPv6First)]
[InlineData(false, SqlConnectionIPAddressPreference.UsePlatformDefault)]
public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailover, SqlConnectionIPAddressPreference ipPreference)
{
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);
builder.MultiSubnetFailover = useMultiSubnetFailover;
builder.IPAddressPreference = ipPreference;


Assert.True(ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName));

if (IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName))
{
builder.DataSource = hostname + "\\" + instanceName;
try
{
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
}
catch (Exception ex)
{
Assert.True(false, "Unexpected connection failure: " + ex.Message);
}
}

builder.ConnectTimeout = 2;
instanceName = "invalidinstance3456";
if (!IsValidInstance(hostname, instanceName))
{
builder.DataSource = hostname + "\\" + instanceName;
try
{
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
connection.Close();
Assert.True(false, "Unexpected connection success against " + instanceName);
}
catch (Exception ex)
{
Assert.Contains("Error Locating Server/Instance Specified", ex.Message);
}
}
}

private static bool ParseDataSource(string dataSource, out string hostname, out int port, out string instanceName)
{
hostname = string.Empty;
port = -1;
instanceName = string.Empty;

if (dataSource.Contains(",") && dataSource.Contains("\\"))
return false;

if (dataSource.Contains(":"))
{
dataSource = dataSource.Substring(dataSource.IndexOf(":") + 1);
}

if (dataSource.Contains(","))
{
if (!int.TryParse(dataSource.Substring(dataSource.LastIndexOf(",") + 1), out port))
{
return false;
}
dataSource = dataSource.Substring(0, dataSource.IndexOf(",") - 1);
}

if (dataSource.Contains("\\"))
{
instanceName = dataSource.Substring(dataSource.LastIndexOf("\\") + 1);
dataSource = dataSource.Substring(0, dataSource.LastIndexOf("\\"));
}

hostname = dataSource;

return hostname.Length > 0 && hostname.IndexOfAny(new char[] { '\\', ':', ',' }) == -1;
}

private static bool IsBrowserAlive(string browserHostname)
{
const byte ClntUcastEx = 0x03;

byte[] responsePacket = QueryBrowser(browserHostname, new byte[] { ClntUcastEx });
return responsePacket != null && responsePacket.Length > 0;
}

private static bool IsValidInstance(string browserHostName, string instanceName)
{
byte[] request = CreateInstanceInfoRequest(instanceName);
byte[] response = QueryBrowser(browserHostName, request);
return response != null && response.Length > 0;
}

private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket)
{
const int DefaultBrowserPort = 1434;
const int sendTimeout = 1000;
const int receiveTimeout = 1000;
const byte ClntUcastEx = 0x03;

byte[] requestPacket = new byte[] { ClntUcastEx };
byte[] responsePacket = null;
using (UdpClient client = new UdpClient(AddressFamily.InterNetwork))
using (UdpClient client = new(AddressFamily.InterNetwork))
{
try
{
@@ -56,7 +150,21 @@ private static bool IsBrowserAlive(string browserHostname)
}
catch { }
}
return responsePacket != null && responsePacket.Length > 0;

return responsePacket;
}

private static byte[] CreateInstanceInfoRequest(string instanceName)
{
const byte ClntUcastInst = 0x04;
instanceName += char.MinValue;
int byteCount = Encoding.ASCII.GetByteCount(instanceName);

byte[] requestPacket = new byte[byteCount + 1];
requestPacket[0] = ClntUcastInst;
Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 1);

return requestPacket;
}
}
}

0 comments on commit 99bc353

Please sign in to comment.