diff --git a/src/System.Data.Common/ref/System.Data.Common.cs b/src/System.Data.Common/ref/System.Data.Common.cs
index 854d72c0bb4c..a0ec1d683847 100644
--- a/src/System.Data.Common/ref/System.Data.Common.cs
+++ b/src/System.Data.Common/ref/System.Data.Common.cs
@@ -2211,6 +2211,13 @@ protected DbProviderFactory() { }
public virtual System.Data.Common.DbDataSourceEnumerator CreateDataSourceEnumerator() { throw null; }
public virtual System.Data.Common.DbParameter CreateParameter() { throw null; }
}
+ public static partial class DbProviderFactories
+ {
+ public static DbProviderFactory GetFactory(string providerInvariantName) { throw null; }
+ public static DbProviderFactory GetFactory(DataRow providerRow) { throw null; }
+ public static DbProviderFactory GetFactory(DbConnection connection) { throw null; }
+ public static DataTable GetFactoryClasses() { throw null; }
+ }
[System.AttributeUsageAttribute((System.AttributeTargets)(128), AllowMultiple = false, Inherited = true)]
public sealed partial class DbProviderSpecificTypePropertyAttribute : System.Attribute
{
diff --git a/src/System.Data.Common/ref/System.Data.Common.csproj b/src/System.Data.Common/ref/System.Data.Common.csproj
index f3525231a91f..6dc2f123b3be 100644
--- a/src/System.Data.Common/ref/System.Data.Common.csproj
+++ b/src/System.Data.Common/ref/System.Data.Common.csproj
@@ -12,6 +12,9 @@
+
+
+
diff --git a/src/System.Data.Common/ref/System.Data.Common.netcoreapp.cs b/src/System.Data.Common/ref/System.Data.Common.netcoreapp.cs
new file mode 100644
index 000000000000..94df2006ad59
--- /dev/null
+++ b/src/System.Data.Common/ref/System.Data.Common.netcoreapp.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+// ------------------------------------------------------------------------------
+// Changes to this file must follow the http://aka.ms/api-review process.
+// ------------------------------------------------------------------------------
+
+using System.Collections.Generic;
+
+namespace System.Data.Common
+{
+ public static partial class DbProviderFactories
+ {
+ public static void RegisterFactory(string providerInvariantName, string factoryTypeAssemblyQualifiedName) { throw null; }
+ public static void RegisterFactory(string providerInvariantName, Type factoryType) { throw null; }
+ public static void RegisterFactory(string providerInvariantName, DbProviderFactory factory) { throw null; }
+ public static bool TryGetFactory(string providerInvariantName, out DbProviderFactory factory) { throw null; }
+ public static bool UnregisterFactory(string providerInvariantName) { throw null; }
+ public static IEnumerable GetProviderInvariantNames() { throw null; }
+ }
+}
diff --git a/src/System.Data.Common/src/Resources/Strings.resx b/src/System.Data.Common/src/Resources/Strings.resx
index 8ac6107e7fc7..f13b9de2d8b7 100644
--- a/src/System.Data.Common/src/Resources/Strings.resx
+++ b/src/System.Data.Common/src/Resources/Strings.resx
@@ -512,4 +512,9 @@
Cannot remove this column, because it is part of an expression: {0} = {1}.
The rowOrder value={0} has been found twice for table named '{1}'.
Cannot find ElementType name='{0}'.
+ The specified invariant name '{0}' wasn't found in the list of registered .NET Data Providers.
+ The requested .NET Data Provider's implementation does not have an Instance field of a System.Data.Common.DbProviderFactory derived type.
+ The registered .NET Data Provider's DbProviderFactory implementation type '{0}' couldn't be loaded.
+ The missing .NET Data Provider's assembly qualified name is required.
+ The type '{0}' doesn't inherit from DbProviderFactory.
diff --git a/src/System.Data.Common/src/System.Data.Common.csproj b/src/System.Data.Common/src/System.Data.Common.csproj
index f510044d987d..5e3d53dee503 100644
--- a/src/System.Data.Common/src/System.Data.Common.csproj
+++ b/src/System.Data.Common/src/System.Data.Common.csproj
@@ -57,7 +57,9 @@
-
+
+ Component
+
@@ -78,10 +80,14 @@
-
+
+ Component
+
-
+
+ Component
+
@@ -91,9 +97,13 @@
-
+
+ Component
+
-
+
+ Component
+
@@ -135,7 +145,9 @@
-
+
+ Component
+
@@ -160,7 +172,9 @@
-
+
+ Component
+
@@ -170,9 +184,15 @@
-
-
-
+
+ Component
+
+
+ Component
+
+
+ Component
+
System\Data\Common\DbConnectionOptions.Common.cs
@@ -183,7 +203,9 @@
-
+
+ Component
+
@@ -195,6 +217,7 @@
+
diff --git a/src/System.Data.Common/src/System/Data/Common/DbConnection.cs b/src/System.Data.Common/src/System/Data/Common/DbConnection.cs
index 3880a544ac7d..766bb516e7b9 100644
--- a/src/System.Data.Common/src/System/Data/Common/DbConnection.cs
+++ b/src/System.Data.Common/src/System/Data/Common/DbConnection.cs
@@ -37,6 +37,8 @@ protected DbConnection() : base()
///
protected virtual DbProviderFactory DbProviderFactory => null;
+ internal DbProviderFactory ProviderFactory => DbProviderFactory;
+
[Browsable(false)]
public abstract string ServerVersion { get; }
diff --git a/src/System.Data.Common/src/System/Data/Common/DbProviderFactories.cs b/src/System.Data.Common/src/System/Data/Common/DbProviderFactories.cs
new file mode 100644
index 000000000000..bc6c482faaf4
--- /dev/null
+++ b/src/System.Data.Common/src/System/Data/Common/DbProviderFactories.cs
@@ -0,0 +1,200 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Globalization;
+using System.Linq;
+using System.Reflection;
+using System.Threading;
+
+namespace System.Data.Common
+{
+ public static partial class DbProviderFactories
+ {
+ private struct ProviderRegistration
+ {
+ internal ProviderRegistration(string factoryTypeAssemblyQualifiedName, DbProviderFactory factoryInstance)
+ {
+ this.FactoryTypeAssemblyQualifiedName = factoryTypeAssemblyQualifiedName;
+ this.FactoryInstance = factoryInstance;
+ }
+
+ internal string FactoryTypeAssemblyQualifiedName { get; }
+ ///
+ /// The cached instance of the type in . If null, this registation is seen as a deferred registration
+ /// and is checked the first time when this registration is requested through GetFactory().
+ ///
+ internal DbProviderFactory FactoryInstance { get; }
+ }
+
+ private static ConcurrentDictionary _registeredFactories = new ConcurrentDictionary();
+ private const string AssemblyQualifiedNameColumnName = "AssemblyQualifiedName";
+ private const string InvariantNameColumnName = "InvariantName";
+ private const string NameColumnName = "Name";
+ private const string DescriptionColumnName = "Description";
+ private const string ProviderGroupColumnName = "DbProviderFactories";
+ private const string InstanceFieldName = "Instance";
+
+ public static bool TryGetFactory(string providerInvariantName, out DbProviderFactory factory)
+ {
+ factory = GetFactory(providerInvariantName, throwOnError: false);
+ return factory != null;
+ }
+
+ public static DbProviderFactory GetFactory(string providerInvariantName)
+ {
+ return GetFactory(providerInvariantName, throwOnError: true);
+ }
+
+ public static DbProviderFactory GetFactory(DataRow providerRow)
+ {
+ ADP.CheckArgumentNull(providerRow, nameof(providerRow));
+
+ DataColumn assemblyQualifiedNameColumn = providerRow.Table.Columns[AssemblyQualifiedNameColumnName];
+ if (null == assemblyQualifiedNameColumn)
+ {
+ throw ADP.Argument(SR.ADP_DbProviderFactories_NoAssemblyQualifiedName);
+ }
+
+ string assemblyQualifiedName = providerRow[assemblyQualifiedNameColumn] as string;
+ if (string.IsNullOrWhiteSpace(assemblyQualifiedName))
+ {
+ throw ADP.Argument(SR.ADP_DbProviderFactories_NoAssemblyQualifiedName);
+ }
+
+ return GetFactoryInstance(GetProviderTypeFromTypeName(assemblyQualifiedName));
+ }
+
+
+ public static DbProviderFactory GetFactory(DbConnection connection)
+ {
+ ADP.CheckArgumentNull(connection, nameof(connection));
+
+ return connection.ProviderFactory;
+ }
+
+ public static DataTable GetFactoryClasses()
+ {
+ DataColumn nameColumn = new DataColumn(NameColumnName, typeof(string)) { ReadOnly = true };
+ DataColumn descriptionColumn = new DataColumn(DescriptionColumnName, typeof(string)) { ReadOnly = true };
+ DataColumn invariantNameColumn = new DataColumn(InvariantNameColumnName, typeof(string)) { ReadOnly = true };
+ DataColumn assemblyQualifiedNameColumn = new DataColumn(AssemblyQualifiedNameColumnName, typeof(string)) { ReadOnly = true };
+
+ DataTable toReturn = new DataTable(ProviderGroupColumnName) { Locale = CultureInfo.InvariantCulture };
+ toReturn.Columns.AddRange(new[] { nameColumn, descriptionColumn, invariantNameColumn, assemblyQualifiedNameColumn });
+ toReturn.PrimaryKey = new[] { invariantNameColumn };
+ foreach(var kvp in _registeredFactories)
+ {
+ DataRow newRow = toReturn.NewRow();
+ newRow[InvariantNameColumnName] = kvp.Key;
+ newRow[AssemblyQualifiedNameColumnName] = kvp.Value.FactoryTypeAssemblyQualifiedName;
+ newRow[NameColumnName] = string.Empty;
+ newRow[DescriptionColumnName] = string.Empty;
+ toReturn.AddRow(newRow);
+ }
+ return toReturn;
+ }
+
+ public static IEnumerable GetProviderInvariantNames()
+ {
+ return _registeredFactories.Keys.ToList();
+ }
+
+ public static void RegisterFactory(string providerInvariantName, string factoryTypeAssemblyQualifiedName)
+ {
+ ADP.CheckArgumentLength(providerInvariantName, nameof(providerInvariantName));
+ ADP.CheckArgumentLength(factoryTypeAssemblyQualifiedName, nameof(factoryTypeAssemblyQualifiedName));
+
+ // this method performs a deferred registration: the type name specified is checked when the factory is requested for the first time.
+ _registeredFactories[providerInvariantName] = new ProviderRegistration(factoryTypeAssemblyQualifiedName, null);
+ }
+
+ public static void RegisterFactory(string providerInvariantName, Type providerFactoryClass)
+ {
+ RegisterFactory(providerInvariantName, GetFactoryInstance(providerFactoryClass));
+ }
+
+ public static void RegisterFactory(string providerInvariantName, DbProviderFactory factory)
+ {
+ ADP.CheckArgumentLength(providerInvariantName, nameof(providerInvariantName));
+ ADP.CheckArgumentNull(factory, nameof(factory));
+
+ _registeredFactories[providerInvariantName] = new ProviderRegistration(factory.GetType().AssemblyQualifiedName, factory);
+ }
+
+ public static bool UnregisterFactory(string providerInvariantName)
+ {
+ return !string.IsNullOrWhiteSpace(providerInvariantName) && _registeredFactories.TryRemove(providerInvariantName, out _);
+ }
+
+ private static DbProviderFactory GetFactory(string providerInvariantName, bool throwOnError)
+ {
+ if (throwOnError)
+ {
+ ADP.CheckArgumentLength(providerInvariantName, nameof(providerInvariantName));
+ }
+ else
+ {
+ if (string.IsNullOrWhiteSpace(providerInvariantName))
+ {
+ return null;
+ }
+ }
+ bool wasRegistered = _registeredFactories.TryGetValue(providerInvariantName, out ProviderRegistration registration);
+ if (!wasRegistered)
+ {
+ return throwOnError ? throw ADP.Argument(SR.Format(SR.ADP_DbProviderFactories_InvariantNameNotFound, providerInvariantName)) : (DbProviderFactory)null;
+ }
+ DbProviderFactory toReturn = registration.FactoryInstance;
+ if (toReturn == null)
+ {
+ // Deferred registration, do checks now on the type specified and register instance in storage.
+ // Even in the case of throwOnError being false, this will throw when an exception occurs checking the registered type as the user has to be notified the
+ // registration is invalid, even though the registration is there.
+ toReturn = GetFactoryInstance(GetProviderTypeFromTypeName(registration.FactoryTypeAssemblyQualifiedName));
+ RegisterFactory(providerInvariantName, toReturn);
+ }
+ return toReturn;
+ }
+
+ private static DbProviderFactory GetFactoryInstance(Type providerFactoryClass)
+ {
+ ADP.CheckArgumentNull(providerFactoryClass, nameof(providerFactoryClass));
+ if (!providerFactoryClass.IsSubclassOf(typeof(DbProviderFactory)))
+ {
+ throw ADP.Argument(SR.Format(SR.ADP_DbProviderFactories_NotAFactoryType, providerFactoryClass.FullName));
+ }
+
+ FieldInfo providerInstance = providerFactoryClass.GetField(InstanceFieldName, BindingFlags.DeclaredOnly | BindingFlags.Public | BindingFlags.Static);
+ if (null == providerInstance)
+ {
+ throw ADP.InvalidOperation(SR.ADP_DbProviderFactories_NoInstance);
+ }
+ if (!providerInstance.FieldType.IsSubclassOf(typeof(DbProviderFactory)))
+ {
+ throw ADP.InvalidOperation(SR.ADP_DbProviderFactories_NoInstance);
+ }
+ object factory = providerInstance.GetValue(null);
+ if (null == factory)
+ {
+ throw ADP.InvalidOperation(SR.ADP_DbProviderFactories_NoInstance);
+ }
+ return (DbProviderFactory)factory;
+ }
+
+
+ private static Type GetProviderTypeFromTypeName(string assemblyQualifiedName)
+ {
+ Type providerType = Type.GetType(assemblyQualifiedName);
+ if (null == providerType)
+ {
+ throw ADP.Argument(SR.Format(SR.ADP_DbProviderFactories_FactoryNotLoadable, assemblyQualifiedName));
+ }
+ return providerType;
+ }
+ }
+}
diff --git a/src/System.Data.Common/tests/Configurations.props b/src/System.Data.Common/tests/Configurations.props
index c398e42e8994..8b803e0772f2 100644
--- a/src/System.Data.Common/tests/Configurations.props
+++ b/src/System.Data.Common/tests/Configurations.props
@@ -2,6 +2,7 @@
+ netcoreapp;
netstandard;
diff --git a/src/System.Data.Common/tests/System.Data.Common.Tests.csproj b/src/System.Data.Common/tests/System.Data.Common.Tests.csproj
index ab5a69e6d06a..14b510272424 100644
--- a/src/System.Data.Common/tests/System.Data.Common.Tests.csproj
+++ b/src/System.Data.Common/tests/System.Data.Common.Tests.csproj
@@ -4,7 +4,10 @@
{B473F77D-4168-4123-932A-E88020B768FA}
0168,0169,0414,0219,0649
+ $(DefineConstants);netcoreapp
+
+
@@ -74,7 +77,9 @@
-
+
+ Component
+
@@ -109,6 +114,9 @@
System\Runtime\Serialization\Formatters\BinaryFormatterHelpers.cs
+
+
+
{69e46a6f-9966-45a5-8945-2559fe337827}
diff --git a/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs b/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs
index 1aebdd0782b5..daac63de95c1 100644
--- a/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs
+++ b/src/System.Data.Common/tests/System/Data/Common/DbConnectionTests.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// See the LICENSE file in the project root for more information.
+using System.Reflection;
using Xunit;
namespace System.Data.Common.Tests
@@ -94,6 +95,47 @@ protected override DbCommand CreateDbCommand()
}
}
+ private class DbProviderFactoryConnection : DbConnection
+ {
+ protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
+ {
+ throw new NotImplementedException();
+ }
+
+ public override void ChangeDatabase(string databaseName)
+ {
+ throw new NotImplementedException();
+ }
+
+ public override void Close()
+ {
+ throw new NotImplementedException();
+ }
+
+ public override void Open()
+ {
+ throw new NotImplementedException();
+ }
+
+ public override string ConnectionString { get; set; }
+ public override string Database { get; }
+ public override ConnectionState State { get; }
+ public override string DataSource { get; }
+ public override string ServerVersion { get; }
+
+ protected override DbCommand CreateDbCommand()
+ {
+ throw new NotImplementedException();
+ }
+
+ protected override DbProviderFactory DbProviderFactory => TestDbProviderFactory.Instance;
+ }
+
+ private class TestDbProviderFactory : DbProviderFactory
+ {
+ public static DbProviderFactory Instance = new TestDbProviderFactory();
+ }
+
[Fact]
[SkipOnTargetFramework(TargetFrameworkMonikers.Mono, "GC has different behavior on Mono")]
public void CanBeFinalized()
@@ -103,5 +145,17 @@ public void CanBeFinalized()
GC.WaitForPendingFinalizers();
Assert.True(_wasFinalized);
}
+
+ [Fact]
+ public void ProviderFactoryTest()
+ {
+ DbProviderFactoryConnection con = new DbProviderFactoryConnection();
+ PropertyInfo providerFactoryProperty = con.GetType().GetProperty("ProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance);
+ Assert.NotNull(providerFactoryProperty);
+ DbProviderFactory factory = providerFactoryProperty.GetValue(con) as DbProviderFactory;
+ Assert.NotNull(factory);
+ Assert.Same(typeof(TestDbProviderFactory), factory.GetType());
+ Assert.Same(TestDbProviderFactory.Instance, factory);
+ }
}
}
diff --git a/src/System.Data.Common/tests/System/Data/Common/DbProviderFactoriesTests.netcoreapp.cs b/src/System.Data.Common/tests/System/Data/Common/DbProviderFactoriesTests.netcoreapp.cs
new file mode 100644
index 000000000000..204523823861
--- /dev/null
+++ b/src/System.Data.Common/tests/System/Data/Common/DbProviderFactoriesTests.netcoreapp.cs
@@ -0,0 +1,207 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Data.Common;
+using System.Linq;
+using Xunit;
+
+namespace System.Data.Common
+{
+ public sealed class TestProviderFactory : DbProviderFactory
+ {
+ public static readonly TestProviderFactory Instance = new TestProviderFactory();
+ private TestProviderFactory() { }
+ }
+
+ public class DbProviderFactoriesTests
+ {
+ [Fact]
+ public void GetFactoryClassesDataTableShapeTest()
+ {
+ DataTable initializedTable = DbProviderFactories.GetFactoryClasses();
+ Assert.NotNull(initializedTable);
+ Assert.Equal(4, initializedTable.Columns.Count);
+ Assert.Equal("Name", initializedTable.Columns[0].ColumnName);
+ Assert.Equal("Description", initializedTable.Columns[1].ColumnName);
+ Assert.Equal("InvariantName", initializedTable.Columns[2].ColumnName);
+ Assert.Equal("AssemblyQualifiedName", initializedTable.Columns[3].ColumnName);
+ }
+
+ [Fact]
+ public void GetFactoryNoRegistrationTest()
+ {
+ ClearRegisteredFactories();
+ Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient"));
+ }
+
+ [Fact]
+ public void GetFactoryWithInvariantNameTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory)));
+ DbProviderFactory factory = DbProviderFactories.GetFactory("System.Data.SqlClient");
+ Assert.NotNull(factory);
+ Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType());
+ Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory);
+ }
+
+ [Fact]
+ public void GetFactoryWithDbConnectionTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory)));
+ DbProviderFactory factory = DbProviderFactories.GetFactory(new System.Data.SqlClient.SqlConnection());
+ Assert.NotNull(factory);
+ Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType());
+ Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory);
+ }
+
+ [Fact]
+ public void GetFactoryWithDataRowTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=> DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory)));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithTypeNameTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory).AssemblyQualifiedName));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithTypeTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory)));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithInstanceTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", System.Data.SqlClient.SqlClientFactory.Instance));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithWrongTypeTest()
+ {
+ ClearRegisteredFactories();
+ Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient"));
+ Assert.Throws(() => DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlConnection)));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithBadInvariantNameTest()
+ {
+ ClearRegisteredFactories();
+ Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient"));
+ Assert.Throws(() => DbProviderFactories.RegisterFactory(string.Empty, typeof(System.Data.SqlClient.SqlClientFactory)));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithAssemblyQualifiedNameTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory).AssemblyQualifiedName));
+ }
+
+ [Fact]
+ public void RegisterFactoryWithWrongAssemblyQualifiedNameTest()
+ {
+ ClearRegisteredFactories();
+ Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient"));
+ DataTable providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(0, providerTable.Rows.Count);
+ // register the connection type which is the wrong type. Registraton should succeed, as type registration/checking is deferred.
+ DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlConnection).AssemblyQualifiedName);
+ providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(1, providerTable.Rows.Count);
+ // obtaining the factory will kick in the checks of the registered type name, which will cause exceptions. The checks were deferred till the GetFactory() call.
+ Assert.Throws(() => DbProviderFactories.GetFactory(providerTable.Rows[0]));
+ Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient"));
+ }
+
+ [Fact]
+ public void UnregisterFactoryTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", System.Data.SqlClient.SqlClientFactory.Instance));
+ Assert.True(DbProviderFactories.UnregisterFactory("System.Data.SqlClient"));
+ DataTable providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(0, providerTable.Rows.Count);
+ }
+
+ [Fact]
+ public void TryGetFactoryTest()
+ {
+ ClearRegisteredFactories();
+ Assert.False(DbProviderFactories.TryGetFactory("System.Data.SqlClient", out DbProviderFactory f));
+ RegisterSqlClientAndTestRegistration(() => DbProviderFactories.RegisterFactory("System.Data.SqlClient", System.Data.SqlClient.SqlClientFactory.Instance));
+ Assert.True(DbProviderFactories.TryGetFactory("System.Data.SqlClient", out DbProviderFactory factory));
+ Assert.NotNull(factory);
+ Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType());
+ Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory);
+ }
+
+ [Fact]
+ public void ReplaceFactoryWithRegisterFactoryWithTypeTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(()=>DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory)));
+ DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(TestProviderFactory));
+ DataTable providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(1, providerTable.Rows.Count);
+ DbProviderFactory factory = DbProviderFactories.GetFactory("System.Data.SqlClient");
+ Assert.NotNull(factory);
+ Assert.Equal(typeof(TestProviderFactory), factory.GetType());
+ Assert.Equal(TestProviderFactory.Instance, factory);
+ }
+
+ [Fact]
+ public void GetProviderInvariantNamesTest()
+ {
+ ClearRegisteredFactories();
+ RegisterSqlClientAndTestRegistration(() => DbProviderFactories.RegisterFactory("System.Data.SqlClient", typeof(System.Data.SqlClient.SqlClientFactory)));
+ DbProviderFactories.RegisterFactory("System.Data.Common.TestProvider", typeof(TestProviderFactory));
+ DataTable providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(2, providerTable.Rows.Count);
+ List invariantNames = DbProviderFactories.GetProviderInvariantNames().ToList();
+ Assert.Equal(invariantNames.Count, 2);
+ Assert.True(invariantNames.Contains("System.Data.Common.TestProvider"));
+ Assert.True(invariantNames.Contains("System.Data.SqlClient"));
+ }
+
+ private void ClearRegisteredFactories()
+ {
+ // as the DbProviderFactories table is shared, for tests we need a clean one before a test starts to make sure the tests always succeed.
+ Type type = typeof(DbProviderFactories);
+ FieldInfo info = type.GetField("_registeredFactories", BindingFlags.NonPublic | BindingFlags.Static);
+ IDictionary providerStorage = info.GetValue(null) as IDictionary;
+ Assert.NotNull(providerStorage);
+ providerStorage.Clear();
+ Assert.Equal(0, providerStorage.Count);
+ }
+
+
+ private void RegisterSqlClientAndTestRegistration(Action registrationFunc)
+ {
+ Assert.NotNull(registrationFunc);
+ Assert.Throws(() => DbProviderFactories.GetFactory("System.Data.SqlClient"));
+ DataTable providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(0, providerTable.Rows.Count);
+ registrationFunc();
+ providerTable = DbProviderFactories.GetFactoryClasses();
+ Assert.Equal(1, providerTable.Rows.Count);
+ DbProviderFactory factory = DbProviderFactories.GetFactory(providerTable.Rows[0]);
+ Assert.NotNull(factory);
+ Assert.Equal(typeof(System.Data.SqlClient.SqlClientFactory), factory.GetType());
+ Assert.Equal(System.Data.SqlClient.SqlClientFactory.Instance, factory);
+ }
+ }
+}