Skip to content

Commit

Permalink
Add IP address preference support for TCP connection (#1015)
Browse files Browse the repository at this point in the history
  • Loading branch information
karinazhou authored May 17, 2021
1 parent fee499f commit 561b535
Show file tree
Hide file tree
Showing 43 changed files with 894 additions and 107 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
<docs>
<members name="SqlConnectionIPAddressPreference">
<SqlConnectionIPAddressPreferenceNetfx>
<summary>
Specifies a value for IP address preference during a TCP connection.
</summary>
<remarks>
<format type="text/markdown"><![CDATA[
## Remarks
If `Multi Subnet Failover` or "Transparent Network IP Resolution" is set to `true`, this setting has no effect.
]]></format>
</remarks>
</SqlConnectionIPAddressPreferenceNetfx>
<SqlConnectionIPAddressPreference>
<summary>
Specifies a value for IP address preference during a TCP connection.
</summary>
<remarks>
<format type="text/markdown"><![CDATA[
## Remarks
If `Multi Subnet Failover` is set to `true`, this setting has no effect.
]]></format>
</remarks>
</SqlConnectionIPAddressPreference>
<IPv4First>
<summary>Connects using IPv4 address(es) first. If the connection fails, try IPv6 address(es), if provided. This is the default value.</summary>
<value>0</value>
</IPv4First>
<IPv6First>
<summary>Connect using IPv6 address(es) first. If the connection fails, try IPv4 address(es), if available.</summary>
<value>1</value>
</IPv6First>
<UsePlatformDefault>
<summary>Connects with IP addresses in the order the underlying platform or operating system provides them.</summary>
<value>2</value>
</UsePlatformDefault>
</members>
</docs>
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,17 @@ False
</remarks>
<exception cref="T:System.ArgumentNullException">To set the value to null, use <see cref="F:System.DBNull.Value" />.</exception>
</DataSource>
<IPAddressPreference>
<summary>Gets or sets the value of IP address preference.</summary>
<returns>Returns IP address preference.</returns>
<remarks>
<format type="text/markdown"><![CDATA[
## Remarks
If `Multi Subnet Failover` is set to `true`, this setting has no effect.
]]></format>
</remarks>
</IPAddressPreference>
<EnclaveAttestationUrl>
<summary>Gets or sets the enclave attestation Url to be used with enclave based Always Encrypted.</summary>
<value>The enclave attestation Url.</value>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,18 @@ public enum SqlConnectionAttestationProtocol
HGS = 3
}
#endif
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionIPAddressPreference.xml' path='docs/members[@name="SqlConnectionIPAddressPreference"]/SqlConnectionIPAddressPreference/*' />
public enum SqlConnectionIPAddressPreference
{
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionIPAddressPreference.xml' path='docs/members[@name="SqlConnectionIPAddressPreference"]/IPv4First/*' />
IPv4First = 0, // default

/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionIPAddressPreference.xml' path='docs/members[@name="SqlConnectionIPAddressPreference"]/IPv6First/*' />
IPv6First = 1,

/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionIPAddressPreference.xml' path='docs/members[@name="SqlConnectionIPAddressPreference"]/UsePlatformDefault/*' />
UsePlatformDefault = 2
}
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionCertificateStoreProvider.xml' path='docs/members[@name="SqlColumnEncryptionCertificateStoreProvider"]/SqlColumnEncryptionCertificateStoreProvider/*'/>
public partial class SqlColumnEncryptionCertificateStoreProvider : Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider
{
Expand Down Expand Up @@ -883,6 +895,10 @@ public SqlConnectionStringBuilder(string connectionString) { }
[System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)]
public string EnclaveAttestationUrl { get { throw null; } set { } }
#endif
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionStringBuilder.xml' path='docs/members[@name="SqlConnectionStringBuilder"]/IPAddressPreference/*'/>
[System.ComponentModel.DisplayNameAttribute("IP Address Preference")]
[System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)]
public Microsoft.Data.SqlClient.SqlConnectionIPAddressPreference IPAddressPreference { get { throw null; } set { } }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnectionStringBuilder.xml' path='docs/members[@name="SqlConnectionStringBuilder"]/Encrypt/*'/>
[System.ComponentModel.DisplayNameAttribute("Encrypt")]
[System.ComponentModel.RefreshPropertiesAttribute(System.ComponentModel.RefreshProperties.All)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ private unsafe struct SNI_CLIENT_CONSUMER_INFO
public TransparentNetworkResolutionMode transparentNetworkResolution;
public int totalTimeout;
public bool isAzureSqlServerEndpoint;
public SqlConnectionIPAddressPreference ipAddressPreference;
public SNI_DNSCache_Info DNSCacheInfo;
}

Expand Down Expand Up @@ -275,6 +276,7 @@ private static extern uint SNIOpenWrapper(
[In] SNIHandle pConn,
out IntPtr ppConn,
[MarshalAs(UnmanagedType.Bool)] bool fSync,
SqlConnectionIPAddressPreference ipPreference,
[In] ref SNI_DNSCache_Info pDNSCachedInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
Expand Down Expand Up @@ -341,7 +343,7 @@ internal static uint SNIInitialize()
return SNIInitialize(IntPtr.Zero);
}

internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SQLDNSInfo cachedDNSInfo)
internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync, SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo)
{
// initialize consumer info for MARS
Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info();
Expand All @@ -353,10 +355,11 @@ internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHan
native_cachedDNSInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
native_cachedDNSInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ref native_cachedDNSInfo);
return SNIOpenWrapper(ref native_consumerInfo, "session:", parent, out pConn, fSync, ipPreference, ref native_cachedDNSInfo);
}

internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache, bool fSync, int timeout, bool fParallel, SQLDNSInfo cachedDNSInfo)
internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string constring, ref IntPtr pConn, byte[] spnBuffer, byte[] instanceName, bool fOverrideCache,
bool fSync, int timeout, bool fParallel, SqlConnectionIPAddressPreference ipPreference, SQLDNSInfo cachedDNSInfo)
{
fixed (byte* pin_instanceName = &instanceName[0])
{
Expand All @@ -379,6 +382,7 @@ internal static unsafe uint SNIOpenSyncEx(ConsumerInfo consumerInfo, string cons
clientConsumerInfo.totalTimeout = SniOpenTimeOut;
clientConsumerInfo.isAzureSqlServerEndpoint = ADP.IsAzureSqlServerEndpoint(constring);

clientConsumerInfo.ipAddressPreference = ipPreference;
clientConsumerInfo.DNSCacheInfo.wszCachedFQDN = cachedDNSInfo?.FQDN;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv4 = cachedDNSInfo?.AddrIPv4;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.Reflection;
Expand Down Expand Up @@ -400,6 +401,110 @@ internal static SqlConnectionAttestationProtocol ConvertToAttestationProtocol(st

#endregion

#region <<IPAddressPreference Utility>>
/// <summary>
/// IP Address Preference.
/// </summary>
private readonly static Dictionary<string, SqlConnectionIPAddressPreference> s_preferenceNames = new(StringComparer.InvariantCultureIgnoreCase);

static DbConnectionStringBuilderUtil()
{
foreach (SqlConnectionIPAddressPreference item in Enum.GetValues(typeof(SqlConnectionIPAddressPreference)))
{
s_preferenceNames.Add(item.ToString(), item);
}
}

/// <summary>
/// Convert a string value to the corresponding IPAddressPreference.
/// </summary>
/// <param name="value">The string representation of the enumeration name to convert.</param>
/// <param name="result">When this method returns, `result` contains an object of type `SqlConnectionIPAddressPreference` whose value is represented by `value` if the operation succeeds.
/// If the parse operation fails, `result` contains the default value of the `SqlConnectionIPAddressPreference` type.</param>
/// <returns>`true` if the value parameter was converted successfully; otherwise, `false`.</returns>
internal static bool TryConvertToIPAddressPreference(string value, out SqlConnectionIPAddressPreference result)
{
if (!s_preferenceNames.TryGetValue(value, out result))
{
result = DbConnectionStringDefaults.IPAddressPreference;
return false;
}
return true;
}

/// <summary>
/// Verifies if the `value` is defined in the expected Enum.
/// </summary>
internal static bool IsValidIPAddressPreference(SqlConnectionIPAddressPreference value)
=> value == SqlConnectionIPAddressPreference.IPv4First
|| value == SqlConnectionIPAddressPreference.IPv6First
|| value == SqlConnectionIPAddressPreference.UsePlatformDefault;

internal static string IPAddressPreferenceToString(SqlConnectionIPAddressPreference value)
=> Enum.GetName(typeof(SqlConnectionIPAddressPreference), value);

internal static SqlConnectionIPAddressPreference ConvertToIPAddressPreference(string keyword, object value)
{
if (value is null)
{
return DbConnectionStringDefaults.IPAddressPreference; // IPv4First
}

if (value is string sValue)
{
// try again after remove leading & trailing whitespaces.
sValue = sValue.Trim();
if (TryConvertToIPAddressPreference(sValue, out SqlConnectionIPAddressPreference result))
{
return result;
}

// string values must be valid
throw ADP.InvalidConnectionOptionValue(keyword);
}
else
{
// the value is not string, try other options
SqlConnectionIPAddressPreference eValue;

if (value is SqlConnectionIPAddressPreference preference)
{
eValue = preference;
}
else if (value.GetType().IsEnum)
{
// explicitly block scenarios in which user tries to use wrong enum types, like:
// builder["SqlConnectionIPAddressPreference"] = EnvironmentVariableTarget.Process;
// workaround: explicitly cast non-SqlConnectionIPAddressPreference enums to int
throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionIPAddressPreference), null);
}
else
{
try
{
// Enum.ToObject allows only integral and enum values (enums are blocked above), raising ArgumentException for the rest
eValue = (SqlConnectionIPAddressPreference)Enum.ToObject(typeof(SqlConnectionIPAddressPreference), value);
}
catch (ArgumentException e)
{
// to be consistent with the messages we send in case of wrong type usage, replace
// the error with our exception, and keep the original one as inner one for troubleshooting
throw ADP.ConvertFailed(value.GetType(), typeof(SqlConnectionIPAddressPreference), e);
}
}

if (IsValidIPAddressPreference(eValue))
{
return eValue;
}
else
{
throw ADP.InvalidEnumerationValue(typeof(SqlConnectionIPAddressPreference), (int)eValue);
}
}
}
#endregion

internal static bool IsValidApplicationIntentValue(ApplicationIntent value)
{
Debug.Assert(Enum.GetNames(typeof(ApplicationIntent)).Length == 2, "ApplicationIntent enum has changed, update needed");
Expand Down Expand Up @@ -728,6 +833,7 @@ internal static partial class DbConnectionStringDefaults
internal const SqlConnectionColumnEncryptionSetting ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
internal const string EnclaveAttestationUrl = _emptyString;
internal const SqlConnectionAttestationProtocol AttestationProtocol = SqlConnectionAttestationProtocol.NotSpecified;
internal const SqlConnectionIPAddressPreference IPAddressPreference = SqlConnectionIPAddressPreference.IPv4First;
}


Expand Down Expand Up @@ -765,6 +871,7 @@ internal static partial class DbConnectionStringKeywords
internal const string ColumnEncryptionSetting = "Column Encryption Setting";
internal const string EnclaveAttestationUrl = "Enclave Attestation Url";
internal const string AttestationProtocol = "Attestation Protocol";
internal const string IPAddressPreference = "IP Address Preference";

// common keywords (OleDb, OracleClient, SqlClient)
internal const string DataSource = "Data Source";
Expand Down Expand Up @@ -793,6 +900,9 @@ internal static class DbConnectionStringSynonyms
//internal const string ApplicationName = APP;
internal const string APP = "app";

// internal const string IPAddressPreference = IPADDRESSPREFERENCE;
internal const string IPADDRESSPREFERENCE = "IPAddressPreference";

//internal const string ApplicationIntent = APPLICATIONINTENT;
internal const string APPLICATIONINTENT = "ApplicationIntent";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,12 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="async">Asynchronous connection</param>
/// <param name="parallel">Attempt parallel connects</param>
/// <param name="isIntegratedSecurity"></param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];

Expand All @@ -284,7 +286,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
case DataSource.Protocol.Admin:
case DataSource.Protocol.None: // default to using tcp if no protocol is provided
case DataSource.Protocol.TCP:
sniHandle = CreateTcpHandle(details, timerExpire, parallel, cachedFQDN, ref pendingDNSInfo);
sniHandle = CreateTcpHandle(details, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo);
break;
case DataSource.Protocol.NP:
sniHandle = CreateNpHandle(details, timerExpire, parallel);
Expand Down Expand Up @@ -374,10 +376,11 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="details">Data source</param>
/// <param name="timerExpire">Timer expiration</param>
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
// TCP Format:
// tcp:<host name>\<instance name>
Expand Down Expand Up @@ -415,7 +418,7 @@ private SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool
port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort;
}

return new SNITCPHandle(hostName, port, timerExpire, parallel, cachedFQDN, ref pendingDNSInfo);
return new SNITCPHandle(hostName, port, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo);
}


Expand Down
Loading

0 comments on commit 561b535

Please sign in to comment.