From 99bc353f1e6ce5a79bc48080b17639fc8db79d18 Mon Sep 17 00:00:00 2001 From: DavoudEshtehari <61173489+DavoudEshtehari@users.noreply.github.com> Date: Tue, 30 Aug 2022 14:25:21 -0700 Subject: [PATCH] Parallelize SSRP requests when MSF is specified (#1578) (#1708) --- .../Microsoft/Data/SqlClient/SNI/SNICommon.cs | 1 + .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 10 +- .../Data/SqlClient/SNI/SNITcpHandle.cs | 23 +- .../src/Microsoft/Data/SqlClient/SNI/SSRP.cs | 196 ++++++++++++++++-- .../SQL/InstanceNameTest/InstanceNameTest.cs | 124 ++++++++++- 5 files changed, 319 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index b57dc4f5f3..6980eb09f2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -108,6 +108,7 @@ internal class SNICommon internal const int ConnTimeoutError = 11; internal const int ConnNotUsableError = 19; internal const int InvalidConnStringError = 25; + internal const int ErrorLocatingServerInstance = 26; internal const int HandshakeFailureError = 31; internal const int InternalExceptionError = 35; internal const int ConnOpenFailedError = 40; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 501a68e401..f9a3c88fa3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -141,7 +141,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// /// IP address preference /// Used for DNS Cache - /// Used for DNS Cache + /// Used for DNS Cache /// SNI handle internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) @@ -263,7 +263,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr /// Should MultiSubnetFailover be used /// IP address preference /// Key for DNS Cache - /// Used for DNS Cache + /// Used for DNS Cache /// SNITCPHandle private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { @@ -285,12 +285,12 @@ private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire try { port = isAdminConnection ? - SSRP.GetDacPortByInstanceName(hostName, details.InstanceName) : - SSRP.GetPortByInstanceName(hostName, details.InstanceName); + SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) : + SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference); } catch (SocketException se) { - SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InvalidConnStringError, se); + SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.ErrorLocatingServerInstance, se); return null; } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 12f6370ecc..73f2c7e6e7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -146,9 +146,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, bool parallel bool reportError = true; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port); - // We will always first try to connect with serverName as before and let the DNS server to resolve the serverName. - // If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with cached IPs based on IPAddressPreference. - // The exceptions will be throw to upper level and be handled as before. + // We will always first try to connect with serverName as before and let DNS resolve the serverName. + // If DNS resolution fails, we will try with IPs in the DNS cache if they exist. We try with cached IPs based on IPAddressPreference. + // Exceptions will be thrown to the caller and be handled as before. try { if (parallel) @@ -280,7 +280,12 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i Task connectTask; Task serverAddrTask = Dns.GetHostAddressesAsync(hostName); - serverAddrTask.Wait(ts); + bool complete = serverAddrTask.Wait(ts); + + // DNS timed out - don't block + if (!complete) + return null; + IPAddress[] serverAddresses = serverAddrTask.Result; if (serverAddresses.Length > MaxParallelIpAddresses) @@ -324,7 +329,6 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i availableSocket = connectTask.Result; return availableSocket; - } // Connect to server with hostName and port. @@ -334,7 +338,14 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); - IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName); + Task serverAddrTask = Dns.GetHostAddressesAsync(serverName); + bool complete = serverAddrTask.Wait(timeout); + + // DNS timed out - don't block + if (!complete) + return null; + + IPAddress[] ipAddresses = serverAddrTask.Result; string IPv4String = null; string IPv6String = null; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 0147b29f17..afd0dde34d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -3,10 +3,13 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Net; using System.Net.Sockets; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.Data.SqlClient.SNI @@ -21,8 +24,11 @@ internal class SSRP /// /// SQL Sever Browser hostname /// instance name to find port number + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference /// port number for given instance name - internal static int GetPortByInstanceName(string browserHostName, string instanceName) + internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); @@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc byte[] responsePacket = null; try { - responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest); + responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference); } catch (SocketException se) { @@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) /// /// SQL Sever Browser hostname /// instance name to lookup DAC port + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference /// DAC port for given instance name - internal static int GetDacPortByInstanceName(string browserHostName, string instanceName) + internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName); - byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest); + byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference); const byte SvrResp = 0x05; const byte ProtocolVersion = 0x01; @@ -131,14 +140,23 @@ private static byte[] CreateDacPortInfoRequest(string instanceName) return requestPacket; } + private class SsrpResult + { + public byte[] ResponsePacket; + public Exception Error; + } + /// /// Sends request to server, and receives response from server by UDP. /// /// UDP server hostname /// UDP server port /// request packet + /// Connection timer expiration + /// query all resolved IP addresses in parallel + /// IP address preference /// response packet from UDP server - private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket) + private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { using (TrySNIEventScope.Create(nameof(SSRP))) { @@ -146,28 +164,174 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re Debug.Assert(port >= 0 && port <= 65535, "Invalid port"); Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array"); - const int sendTimeOutMs = 1000; - const int receiveTimeOutMs = 1000; + bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address); - IPAddress address = null; - bool isIpAddress = IPAddress.TryParse(browserHostname, out address); + TimeSpan ts = default; + // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count + // The infinite Timeout is a function of ConnectionString Timeout=0 + if (long.MaxValue != timerExpire) + { + ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; + ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; + } - byte[] responsePacket = null; - using (UdpClient client = new UdpClient(!isIpAddress ? AddressFamily.InterNetwork : address.AddressFamily)) + IPAddress[] ipAddresses = null; + if (!isIpAddress) + { + Task serverAddrTask = Dns.GetHostAddressesAsync(browserHostname); + bool taskComplete; + try + { + taskComplete = serverAddrTask.Wait(ts); + } + catch (AggregateException ae) + { + throw ae.InnerException; + } + + // If DNS took too long, need to return instead of blocking + if (!taskComplete) + return null; + + ipAddresses = serverAddrTask.Result; + } + + Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); + + switch (ipPreference) { - Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, browserHostname, port); + case SqlConnectionIPAddressPreference.IPv4First: + { + SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel); + if (response4 != null && response4.ResponsePacket != null) + return response4.ResponsePacket; + + SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel); + if (response6 != null && response6.ResponsePacket != null) + return response6.ResponsePacket; + + // No responses so throw first error + if (response4 != null && response4.Error != null) + throw response4.Error; + else if (response6 != null && response6.Error != null) + throw response6.Error; + + break; + } + case SqlConnectionIPAddressPreference.IPv6First: + { + SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel); + if (response6 != null && response6.ResponsePacket != null) + return response6.ResponsePacket; + + SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel); + if (response4 != null && response4.ResponsePacket != null) + return response4.ResponsePacket; + + // No responses so throw first error + if (response6 != null && response6.Error != null) + throw response6.Error; + else if (response4 != null && response4.Error != null) + throw response4.Error; + + break; + } + default: + { + SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel); + if (response != null && response.ResponsePacket != null) + return response.ResponsePacket; + else if (response != null && response.Error != null) + throw response.Error; + + break; + } + } + + return null; + } + } + + /// + /// Sends request to server, and receives response from server by UDP. + /// + /// IP Addresses + /// UDP server port + /// request packet + /// query all resolved IP addresses in parallel + /// response packet from UDP server + private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel) + { + if (ipAddresses.Length == 0) + return null; + + if (allIPsInParallel) // Used for MultiSubnetFailover + { + List> tasks = new(ipAddresses.Length); + CancellationTokenSource cts = new CancellationTokenSource(); + for (int i = 0; i < ipAddresses.Length; i++) + { + IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port); + tasks.Add(Task.Factory.StartNew(() => SendUDPRequest(endPoint, requestPacket))); + } + + List> completedTasks = new(); + while (tasks.Count > 0) + { + int first = Task.WaitAny(tasks.ToArray()); + if (tasks[first].Result.ResponsePacket != null) + { + cts.Cancel(); + return tasks[first].Result; + } + else + { + completedTasks.Add(tasks[first]); + tasks.Remove(tasks[first]); + } + } + + Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0"); + + // All tasks failed. Return the error from the first failure. + return completedTasks[0].Result; + } + else + { + // If not parallel, use the first IP address provided + IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port); + return SendUDPRequest(endPoint, requestPacket); + } + } + + private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket) + { + const int sendTimeOutMs = 1000; + const int receiveTimeOutMs = 1000; + + SsrpResult result = new(); + + try + { + using (UdpClient client = new UdpClient(endPoint.AddressFamily)) + { + Task sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint); Task receiveTask = null; - + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info."); if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs)) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client."); - responsePacket = receiveTask.Result.Buffer; + result.ResponsePacket = receiveTask.Result.Buffer; } } - - return responsePacket; } + catch (Exception e) + { + result.Error = e; + } + + return result; } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index 3202636c3c..ceba949d55 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -2,7 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Net.Sockets; +using System.Text; using System.Threading.Tasks; using Xunit; @@ -13,7 +15,7 @@ public static class InstanceNameTest [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void ConnectToSQLWithInstanceNameTest() { - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString); + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); bool proceed = true; string dataSourceStr = builder.DataSource.Replace("tcp:", ""); @@ -26,24 +28,116 @@ public static void ConnectToSQLWithInstanceNameTest() if (proceed) { - using (SqlConnection connection = new SqlConnection(builder.ConnectionString)) + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + connection.Close(); + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer))] + [InlineData(true, SqlConnectionIPAddressPreference.IPv4First)] + [InlineData(true, SqlConnectionIPAddressPreference.IPv6First)] + [InlineData(true, SqlConnectionIPAddressPreference.UsePlatformDefault)] + [InlineData(false, SqlConnectionIPAddressPreference.IPv4First)] + [InlineData(false, SqlConnectionIPAddressPreference.IPv6First)] + [InlineData(false, SqlConnectionIPAddressPreference.UsePlatformDefault)] + public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailover, SqlConnectionIPAddressPreference ipPreference) + { + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString); + builder.MultiSubnetFailover = useMultiSubnetFailover; + builder.IPAddressPreference = ipPreference; + + + Assert.True(ParseDataSource(builder.DataSource, out string hostname, out _, out string instanceName)); + + if (IsBrowserAlive(hostname) && IsValidInstance(hostname, instanceName)) + { + builder.DataSource = hostname + "\\" + instanceName; + try + { + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + } + catch (Exception ex) + { + Assert.True(false, "Unexpected connection failure: " + ex.Message); + } + } + + builder.ConnectTimeout = 2; + instanceName = "invalidinstance3456"; + if (!IsValidInstance(hostname, instanceName)) + { + builder.DataSource = hostname + "\\" + instanceName; + try { + using SqlConnection connection = new(builder.ConnectionString); connection.Open(); - connection.Close(); + Assert.True(false, "Unexpected connection success against " + instanceName); + } + catch (Exception ex) + { + Assert.Contains("Error Locating Server/Instance Specified", ex.Message); } } } + private static bool ParseDataSource(string dataSource, out string hostname, out int port, out string instanceName) + { + hostname = string.Empty; + port = -1; + instanceName = string.Empty; + + if (dataSource.Contains(",") && dataSource.Contains("\\")) + return false; + + if (dataSource.Contains(":")) + { + dataSource = dataSource.Substring(dataSource.IndexOf(":") + 1); + } + + if (dataSource.Contains(",")) + { + if (!int.TryParse(dataSource.Substring(dataSource.LastIndexOf(",") + 1), out port)) + { + return false; + } + dataSource = dataSource.Substring(0, dataSource.IndexOf(",") - 1); + } + + if (dataSource.Contains("\\")) + { + instanceName = dataSource.Substring(dataSource.LastIndexOf("\\") + 1); + dataSource = dataSource.Substring(0, dataSource.LastIndexOf("\\")); + } + + hostname = dataSource; + + return hostname.Length > 0 && hostname.IndexOfAny(new char[] { '\\', ':', ',' }) == -1; + } + private static bool IsBrowserAlive(string browserHostname) + { + const byte ClntUcastEx = 0x03; + + byte[] responsePacket = QueryBrowser(browserHostname, new byte[] { ClntUcastEx }); + return responsePacket != null && responsePacket.Length > 0; + } + + private static bool IsValidInstance(string browserHostName, string instanceName) + { + byte[] request = CreateInstanceInfoRequest(instanceName); + byte[] response = QueryBrowser(browserHostName, request); + return response != null && response.Length > 0; + } + + private static byte[] QueryBrowser(string browserHostname, byte[] requestPacket) { const int DefaultBrowserPort = 1434; const int sendTimeout = 1000; const int receiveTimeout = 1000; - const byte ClntUcastEx = 0x03; - - byte[] requestPacket = new byte[] { ClntUcastEx }; byte[] responsePacket = null; - using (UdpClient client = new UdpClient(AddressFamily.InterNetwork)) + using (UdpClient client = new(AddressFamily.InterNetwork)) { try { @@ -56,7 +150,21 @@ private static bool IsBrowserAlive(string browserHostname) } catch { } } - return responsePacket != null && responsePacket.Length > 0; + + return responsePacket; + } + + private static byte[] CreateInstanceInfoRequest(string instanceName) + { + const byte ClntUcastInst = 0x04; + instanceName += char.MinValue; + int byteCount = Encoding.ASCII.GetByteCount(instanceName); + + byte[] requestPacket = new byte[byteCount + 1]; + requestPacket[0] = ClntUcastInst; + Encoding.ASCII.GetBytes(instanceName, 0, instanceName.Length, requestPacket, 1); + + return requestPacket; } } }