diff --git a/src/System.Data.Common/ref/System.Data.Common.cs b/src/System.Data.Common/ref/System.Data.Common.cs index 854d72c0bb4c..a0ec1d683847 100644 --- a/src/System.Data.Common/ref/System.Data.Common.cs +++ b/src/System.Data.Common/ref/System.Data.Common.cs @@ -2211,6 +2211,13 @@ protected DbProviderFactory() { } public virtual System.Data.Common.DbDataSourceEnumerator CreateDataSourceEnumerator() { throw null; } public virtual System.Data.Common.DbParameter CreateParameter() { throw null; } } + public static partial class DbProviderFactories + { + public static DbProviderFactory GetFactory(string providerInvariantName) { throw null; } + public static DbProviderFactory GetFactory(DataRow providerRow) { throw null; } + public static DbProviderFactory GetFactory(DbConnection connection) { throw null; } + public static DataTable GetFactoryClasses() { throw null; } + } [System.AttributeUsageAttribute((System.AttributeTargets)(128), AllowMultiple = false, Inherited = true)] public sealed partial class DbProviderSpecificTypePropertyAttribute : System.Attribute { diff --git a/src/System.Data.Common/ref/System.Data.Common.csproj b/src/System.Data.Common/ref/System.Data.Common.csproj index f3525231a91f..6dc2f123b3be 100644 --- a/src/System.Data.Common/ref/System.Data.Common.csproj +++ b/src/System.Data.Common/ref/System.Data.Common.csproj @@ -12,6 +12,9 @@ + + + diff --git a/src/System.Data.Common/ref/System.Data.Common.netcoreapp.cs b/src/System.Data.Common/ref/System.Data.Common.netcoreapp.cs new file mode 100644 index 000000000000..94df2006ad59 --- /dev/null +++ b/src/System.Data.Common/ref/System.Data.Common.netcoreapp.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// ------------------------------------------------------------------------------ +// Changes to this file must follow the http://aka.ms/api-review process. +// ------------------------------------------------------------------------------ + +using System.Collections.Generic; + +namespace System.Data.Common +{ + public static partial class DbProviderFactories + { + public static void RegisterFactory(string providerInvariantName, string factoryTypeAssemblyQualifiedName) { throw null; } + public static void RegisterFactory(string providerInvariantName, Type factoryType) { throw null; } + public static void RegisterFactory(string providerInvariantName, DbProviderFactory factory) { throw null; } + public static bool TryGetFactory(string providerInvariantName, out DbProviderFactory factory) { throw null; } + public static bool UnregisterFactory(string providerInvariantName) { throw null; } + public static IEnumerable GetProviderInvariantNames() { throw null; } + } +} diff --git a/src/System.Data.Common/src/Resources/Strings.resx b/src/System.Data.Common/src/Resources/Strings.resx index 8ac6107e7fc7..f13b9de2d8b7 100644 --- a/src/System.Data.Common/src/Resources/Strings.resx +++ b/src/System.Data.Common/src/Resources/Strings.resx @@ -512,4 +512,9 @@ Cannot remove this column, because it is part of an expression: {0} = {1}. The rowOrder value={0} has been found twice for table named '{1}'. Cannot find ElementType name='{0}'. + The specified invariant name '{0}' wasn't found in the list of registered .NET Data Providers. + The requested .NET Data Provider's implementation does not have an Instance field of a System.Data.Common.DbProviderFactory derived type. + The registered .NET Data Provider's DbProviderFactory implementation type '{0}' couldn't be loaded. + The missing .NET Data Provider's assembly qualified name is required. + The type '{0}' doesn't inherit from DbProviderFactory. diff --git a/src/System.Data.Common/src/System.Data.Common.csproj b/src/System.Data.Common/src/System.Data.Common.csproj index f510044d987d..5e3d53dee503 100644 --- a/src/System.Data.Common/src/System.Data.Common.csproj +++ b/src/System.Data.Common/src/System.Data.Common.csproj @@ -57,7 +57,9 @@ - + + Component + @@ -78,10 +80,14 @@ - + + Component + - + + Component + @@ -91,9 +97,13 @@ - + + Component + - + + Component + @@ -135,7 +145,9 @@ - + + Component + @@ -160,7 +172,9 @@ - + + Component + @@ -170,9 +184,15 @@ - - - + + Component + + + Component + + + Component + System\Data\Common\DbConnectionOptions.Common.cs @@ -183,7 +203,9 @@ - + + Component + @@ -195,6 +217,7 @@ + diff --git a/src/System.Data.Common/src/System/Data/Common/DbConnection.cs b/src/System.Data.Common/src/System/Data/Common/DbConnection.cs index 3880a544ac7d..766bb516e7b9 100644 --- a/src/System.Data.Common/src/System/Data/Common/DbConnection.cs +++ b/src/System.Data.Common/src/System/Data/Common/DbConnection.cs @@ -37,6 +37,8 @@ protected DbConnection() : base() /// protected virtual DbProviderFactory DbProviderFactory => null; + internal DbProviderFactory ProviderFactory => DbProviderFactory; + [Browsable(false)] public abstract string ServerVersion { get; } diff --git a/src/System.Data.Common/src/System/Data/Common/DbProviderFactories.cs b/src/System.Data.Common/src/System/Data/Common/DbProviderFactories.cs new file mode 100644 index 000000000000..bc6c482faaf4 --- /dev/null +++ b/src/System.Data.Common/src/System/Data/Common/DbProviderFactories.cs @@ -0,0 +1,200 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Linq; +using System.Reflection; +using System.Threading; + +namespace System.Data.Common +{ + public static partial class DbProviderFactories + { + private struct ProviderRegistration + { + internal ProviderRegistration(string factoryTypeAssemblyQualifiedName, DbProviderFactory factoryInstance) + { + this.FactoryTypeAssemblyQualifiedName = factoryTypeAssemblyQualifiedName; + this.FactoryInstance = factoryInstance; + } + + internal string FactoryTypeAssemblyQualifiedName { get; } + /// + /// The cached instance of the type in . If null, this registation is seen as a deferred registration + /// and is checked the first time when this registration is requested through GetFactory(). + /// + internal DbProviderFactory FactoryInstance { get; } + } + + private static ConcurrentDictionary _registeredFactories = new ConcurrentDictionary(); + private const string AssemblyQualifiedNameColumnName = "AssemblyQualifiedName"; + private const string InvariantNameColumnName = "InvariantName"; + private const string NameColumnName = "Name"; + private const string DescriptionColumnName = "Description"; + private const string ProviderGroupColumnName = "DbProviderFactories"; + private const string InstanceFieldName = "Instance"; + + public static bool TryGetFactory(string providerInvariantName, out DbProviderFactory factory) + { + factory = GetFactory(providerInvariantName, throwOnError: false); + return factory != null; + } + + public static DbProviderFactory GetFactory(string providerInvariantName) + { + return GetFactory(providerInvariantName, throwOnError: true); + } + + public static DbProviderFactory GetFactory(DataRow providerRow) + { + ADP.CheckArgumentNull(providerRow, nameof(providerRow)); + + DataColumn assemblyQualifiedNameColumn = providerRow.Table.Columns[AssemblyQualifiedNameColumnName]; + if (null == assemblyQualifiedNameColumn) + { + throw ADP.Argument(SR.ADP_DbProviderFactories_NoAssemblyQualifiedName); + } + + string assemblyQualifiedName = providerRow[assemblyQualifiedNameColumn] as string; + if (string.IsNullOrWhiteSpace(assemblyQualifiedName)) + { + throw ADP.Argument(SR.ADP_DbProviderFactories_NoAssemblyQualifiedName); + } + + return GetFactoryInstance(GetProviderTypeFromTypeName(assemblyQualifiedName)); + } + + + public static DbProviderFactory GetFactory(DbConnection connection) + { + ADP.CheckArgumentNull(connection, nameof(connection)); + + return connection.ProviderFactory; + } + + public static DataTable GetFactoryClasses() + { + DataColumn nameColumn = new DataColumn(NameColumnName, typeof(string)) { ReadOnly = true }; + DataColumn descriptionColumn = new DataColumn(DescriptionColumnName, typeof(string)) { ReadOnly = true }; + DataColumn invariantNameColumn = new DataColumn(InvariantNameColumnName, typeof(string)) { ReadOnly = true }; + DataColumn assemblyQualifiedNameColumn = new DataColumn(AssemblyQualifiedNameColumnName, typeof(string)) { ReadOnly = true }; + + DataTable toReturn = new DataTable(ProviderGroupColumnName) { Locale = CultureInfo.InvariantCulture }; + toReturn.Columns.AddRange(new[] { nameColumn, descriptionColumn, invariantNameColumn, assemblyQualifiedNameColumn }); + toReturn.PrimaryKey = new[] { invariantNameColumn }; + foreach(var kvp in _registeredFactories) + { + DataRow newRow = toReturn.NewRow(); + newRow[InvariantNameColumnName] = kvp.Key; + newRow[AssemblyQualifiedNameColumnName] = kvp.Value.FactoryTypeAssemblyQualifiedName; + newRow[NameColumnName] = string.Empty; + newRow[DescriptionColumnName] = string.Empty; + toReturn.AddRow(newRow); + } + return toReturn; + } + + public static IEnumerable GetProviderInvariantNames() + { + return _registeredFactories.Keys.ToList(); + } + + public static void RegisterFactory(string providerInvariantName, string factoryTypeAssemblyQualifiedName) + { + ADP.CheckArgumentLength(providerInvariantName, nameof(providerInvariantName)); + ADP.CheckArgumentLength(factoryTypeAssemblyQualifiedName, nameof(factoryTypeAssemblyQualifiedName)); + + // this method performs a deferred registration: the type name specified is checked when the factory is requested for the first time. + _registeredFactories[providerInvariantName] = new ProviderRegistration(factoryTypeAssemblyQualifiedName, null); + } + + public static void RegisterFactory(string providerInvariantName, Type providerFactoryClass) + { + RegisterFactory(providerInvariantName, GetFactoryInstance(providerFactoryClass)); + } + + public static void RegisterFactory(string providerInvariantName, DbProviderFactory factory) + { + ADP.CheckArgumentLength(providerInvariantName, nameof(providerInvariantName)); + ADP.CheckArgumentNull(factory, nameof(factory)); + + _registeredFactories[providerInvariantName] = new ProviderRegistration(factory.GetType().AssemblyQualifiedName, factory); + } + + public static bool UnregisterFactory(string providerInvariantName) + { + return !string.IsNullOrWhiteSpace(providerInvariantName) && _registeredFactories.TryRemove(providerInvariantName, out _); + } + + private static DbProviderFactory GetFactory(string providerInvariantName, bool throwOnError) + { + if (throwOnError) + { + ADP.CheckArgumentLength(providerInvariantName, nameof(providerInvariantName)); + } + else + { + if (string.IsNullOrWhiteSpace(providerInvariantName)) + { + return null; + } + } + bool wasRegistered = _registeredFactories.TryGetValue(providerInvariantName, out ProviderRegistration registration); + if (!wasRegistered) + { + return throwOnError ? throw ADP.Argument(SR.Format(SR.ADP_DbProviderFactories_InvariantNameNotFound, providerInvariantName)) : (DbProviderFactory)null; + } + DbProviderFactory toReturn = registration.FactoryInstance; + if (toReturn == null) + { + // Deferred registration, do checks now on the type specified and register instance in storage. + // Even in the case of throwOnError being false, this will throw when an exception occurs checking the registered type as the user has to be notified the + // registration is invalid, even though the registration is there. + toReturn = GetFactoryInstance(GetProviderTypeFromTypeName(registration.FactoryTypeAssemblyQualifiedName)); + RegisterFactory(providerInvariantName, toReturn); + } + return toReturn; + } + + private static DbProviderFactory GetFactoryInstance(Type providerFactoryClass) + { + ADP.CheckArgumentNull(providerFactoryClass, nameof(providerFactoryClass)); + if (!providerFactoryClass.IsSubclassOf(typeof(DbProviderFactory))) + { + throw ADP.Argument(SR.Format(SR.ADP_DbProviderFactories_NotAFactoryType, providerFactoryClass.FullName)); + } + + FieldInfo providerInstance = providerFactoryClass.GetField(InstanceFieldName, BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.Static); + if (null == providerInstance) + { + throw ADP.InvalidOperation(SR.ADP_DbProviderFactories_NoInstance); + } + if (!providerInstance.FieldType.IsSubclassOf(typeof(DbProviderFactory))) + { + throw ADP.InvalidOperation(SR.ADP_DbProviderFactories_NoInstance); + } + object factory = providerInstance.GetValue(null); + if (null == factory) + { + throw ADP.InvalidOperation(SR.ADP_DbProviderFactories_NoInstance); + } + return (DbProviderFactory)factory; + } + + + private static Type GetProviderTypeFromTypeName(string assemblyQualifiedName) + { + Type providerType = Type.GetType(assemblyQualifiedName); + if (null == providerType) + { + throw ADP.Argument(SR.Format(SR.ADP_DbProviderFactories_FactoryNotLoadable, assemblyQualifiedName)); + } + return providerType; + } + } +} diff --git a/src/System.Data.Common/tests/Configurations.props b/src/System.Data.Common/tests/Configurations.props index c398e42e8994..8b803e0772f2 100644 --- a/src/System.Data.Common/tests/Configurations.props +++ b/src/System.Data.Common/tests/Configurations.props @@ -2,6 +2,7 @@ + netcoreapp; netstandard; diff --git a/src/System.Data.Common/tests/System.Data.Common.Tests.csproj b/src/System.Data.Common/tests/System.Data.Common.Tests.csproj index ab5a69e6d06a..14b510272424 100644 --- a/src/System.Data.Common/tests/System.Data.Common.Tests.csproj +++ b/src/System.Data.Common/tests/System.Data.Common.Tests.csproj @@ -4,7 +4,10 @@ {B473F77D-4168-4123-932A-E88020B768FA} 0168,0169,0414,0219,0649 + $(DefineConstants);netcoreapp + + @@ -74,7 +77,9 @@ - + + Component + @@ -109,6 +114,9 @@ System\Runtime\Serialization\Formatters\BinaryFormatterHelpers.cs + + + {69e46a6f-9966-45a5-8945-2559fe337827} diff --git a/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs b/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs index 1aebdd0782b5..daac63de95c1 100644 --- a/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs +++ b/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // See the LICENSE file in the project root for more information. +using System.Reflection; using Xunit; namespace System.Data.Common.Tests @@ -94,6 +95,47 @@ protected override DbCommand CreateDbCommand() } } + private class DbProviderFactoryConnection : DbConnection + { + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) + { + throw new NotImplementedException(); + } + + public override void ChangeDatabase(string databaseName) + { + throw new NotImplementedException(); + } + + public override void Close() + { + throw new NotImplementedException(); + } + + public override void Open() + { + throw new NotImplementedException(); + } + + public override string ConnectionString { get; set; } + public override string Database { get; } + public override ConnectionState State { get; } + public override string DataSource { get; } + public override string ServerVersion { get; } + + protected override DbCommand CreateDbCommand() + { + throw new NotImplementedException(); + } + + protected override DbProviderFactory DbProviderFactory => TestDbProviderFactory.Instance; + } + + private class TestDbProviderFactory : DbProviderFactory + { + public static DbProviderFactory Instance = new TestDbProviderFactory(); + } + [Fact] [SkipOnTargetFramework(TargetFrameworkMonikers.Mono, "GC has different behavior on Mono")] public void CanBeFinalized() @@ -103,5 +145,17 @@ public void CanBeFinalized() GC.WaitForPendingFinalizers(); Assert.True(_wasFinalized); } + + [Fact] + public void ProviderFactoryTest() + { + DbProviderFactoryConnection con = new DbProviderFactoryConnection(); + PropertyInfo providerFactoryProperty = con.GetType().GetProperty("ProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(providerFactoryProperty); + DbProviderFactory factory = providerFactoryProperty.GetValue(con) as DbProviderFactory; + Assert.NotNull(factory); + Assert.Same(typeof(TestDbProviderFactory), factory.GetType()); + Assert.Same(TestDbProviderFactory.Instance, factory); + } } } diff --git a/src/System.Data.Common/tests/System/Data/Common/DbProviderFactoriesTests.netcoreapp.cs b/src/System.Data.Common/tests/System/Data/Common/DbProviderFactoriesTests.netcoreapp.cs new file mode 100644 index 000000000000..204523823861 --- /dev/null +++ b/src/System.Data.Common/tests/System/Data/Common/DbProviderFactoriesTests.netcoreapp.cs @@ -0,0 +1,207 @@ +// Licensed to the .NET Foundation under one or more agreements. +// See the LICENSE file in the project root for more information. + +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Reflection; +using System.Data.Common; +using System.Linq; +using Xunit; + +namespace System.Data.Common +{ + public sealed class TestProviderFactory : DbProviderFactory + { + public static readonly TestProviderFactory Instance = new TestProviderFactory(); + private TestProviderFactory() { } + } + + public class DbProviderFactoriesTests + { + [Fact] + public void GetFactoryClassesDataTableShapeTest() + { + DataTable initializedTable = DbProviderFactories.GetFactoryClasses(); + Assert.NotNull(initializedTable); + Assert.Equal(4, initializedTable.Columns.Count); + Assert.Equal("Name", initializedTable.Columns[0].ColumnName); + Assert.Equal("Description", initializedTable.Columns[1].ColumnName); + Assert.Equal("InvariantName", initializedTable.Columns[2].ColumnName); + Assert.Equal("AssemblyQualifiedName", initializedTable.Columns[3].ColumnName); + } + + [Fact] + public void GetFactoryNoRegistrationTest() + { + ClearRegisteredFactories(); + Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient")); + } + + [Fact] + public void GetFactoryWithInvariantNameTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory))); + DbProviderFactory factory = DbProviderFactories.GetFactory("System.Data.SqlClient"); + Assert.NotNull(factory); + Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType()); + Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory); + } + + [Fact] + public void GetFactoryWithDbConnectionTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory))); + DbProviderFactory factory = DbProviderFactories.GetFactory(new System.Data.SqlClient.SqlConnection()); + Assert.NotNull(factory); + Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType()); + Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory); + } + + [Fact] + public void GetFactoryWithDataRowTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=> DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory))); + } + + [Fact] + public void RegisterFactoryWithTypeNameTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory).AssemblyQualifiedName)); + } + + [Fact] + public void RegisterFactoryWithTypeTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory))); + } + + [Fact] + public void RegisterFactoryWithInstanceTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", System.Data.SqlClient.SqlClientFactory.Instance)); + } + + [Fact] + public void RegisterFactoryWithWrongTypeTest() + { + ClearRegisteredFactories(); + Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient")); + Assert.Throws(() => DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlConnection))); + } + + [Fact] + public void RegisterFactoryWithBadInvariantNameTest() + { + ClearRegisteredFactories(); + Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient")); + Assert.Throws(() => DbProviderFactories.RegisterFactory(string.Empty, typeof(System.Data.SqlClient.SqlClientFactory))); + } + + [Fact] + public void RegisterFactoryWithAssemblyQualifiedNameTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory).AssemblyQualifiedName)); + } + + [Fact] + public void RegisterFactoryWithWrongAssemblyQualifiedNameTest() + { + ClearRegisteredFactories(); + Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient")); + DataTable providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(0, providerTable.Rows.Count); + // register the connection type which is the wrong type. Registraton should succeed, as type registration/checking is deferred. + DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlConnection).AssemblyQualifiedName); + providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(1, providerTable.Rows.Count); + // obtaining the factory will kick in the checks of the registered type name, which will cause exceptions. The checks were deferred till the GetFactory() call. + Assert.Throws(() => DbProviderFactories.GetFactory(providerTable.Rows[0])); + Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient")); + } + + [Fact] + public void UnregisterFactoryTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", System.Data.SqlClient.SqlClientFactory.Instance)); + Assert.True(DbProviderFactories.UnregisterFactory("System.Data.SqlClient")); + DataTable providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(0, providerTable.Rows.Count); + } + + [Fact] + public void TryGetFactoryTest() + { + ClearRegisteredFactories(); + Assert.False(DbProviderFactories.TryGetFactory("System.Data.SqlClient", out DbProviderFactory f)); + RegisterSqlClientAndTestRegistration(() => DbProviderFactories.RegisterFactory("System.Data.SqlClient", System.Data.SqlClient.SqlClientFactory.Instance)); + Assert.True(DbProviderFactories.TryGetFactory("System.Data.SqlClient", out DbProviderFactory factory)); + Assert.NotNull(factory); + Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType()); + Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory); + } + + [Fact] + public void ReplaceFactoryWithRegisterFactoryWithTypeTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory))); + DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(TestProviderFactory)); + DataTable providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(1, providerTable.Rows.Count); + DbProviderFactory factory = DbProviderFactories.GetFactory("System.Data.SqlClient"); + Assert.NotNull(factory); + Assert.Equal(typeof(TestProviderFactory), factory.GetType()); + Assert.Equal(TestProviderFactory.Instance, factory); + } + + [Fact] + public void GetProviderInvariantNamesTest() + { + ClearRegisteredFactories(); + RegisterSqlClientAndTestRegistration(() => DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory))); + DbProviderFactories.RegisterFactory("System.Data.Common.TestProvider", typeof(TestProviderFactory)); + DataTable providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(2, providerTable.Rows.Count); + List invariantNames = DbProviderFactories.GetProviderInvariantNames().ToList(); + Assert.Equal(invariantNames.Count, 2); + Assert.True(invariantNames.Contains("System.Data.Common.TestProvider")); + Assert.True(invariantNames.Contains("System.Data.SqlClient")); + } + + private void ClearRegisteredFactories() + { + // as the DbProviderFactories table is shared, for tests we need a clean one before a test starts to make sure the tests always succeed. + Type type = typeof(DbProviderFactories); + FieldInfo info = type.GetField("_registeredFactories", BindingFlags.NonPublic | BindingFlags.Static); + IDictionary providerStorage = info.GetValue(null) as IDictionary; + Assert.NotNull(providerStorage); + providerStorage.Clear(); + Assert.Equal(0, providerStorage.Count); + } + + + private void RegisterSqlClientAndTestRegistration(Action registrationFunc) + { + Assert.NotNull(registrationFunc); + Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient")); + DataTable providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(0, providerTable.Rows.Count); + registrationFunc(); + providerTable = DbProviderFactories.GetFactoryClasses(); + Assert.Equal(1, providerTable.Rows.Count); + DbProviderFactory factory = DbProviderFactories.GetFactory(providerTable.Rows[0]); + Assert.NotNull(factory); + Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType()); + Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory); + } + } +}