Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

force aggressive JIT optimization and disable inlining for patching methods #21

Merged
merged 1 commit into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 94 additions & 81 deletions src/Softwarehelden.Transactions.Oletx/MsSqlPatcher.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using HarmonyLib;
using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace Softwarehelden.Transactions.Oletx
{
Expand All @@ -20,84 +22,89 @@ public static class MsSqlPatcher
/// <param name="assembly">Microsoft.Data.SqlClient or System.Data.SqlClient assembly</param>
public static void Patch(Assembly assembly)
{
if (assembly == null)
// Enable MSSQL data provider patch for Windows only since distributed transactions only
// work on Windows
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
throw new ArgumentNullException(nameof(assembly), "The Microsoft.Data.SqlClient or System.Data.SqlClient assembly must be set.");
}
if (assembly == null)
{
throw new ArgumentNullException(nameof(assembly), "The Microsoft.Data.SqlClient or System.Data.SqlClient assembly must be set.");
}

string assemblyName = assembly.GetName().Name;

var tdsParserType = assembly.GetType($"{assemblyName}.TdsParser");
var tdsParserStateObjectType = assembly.GetType($"{assemblyName}.TdsParserStateObject");
var sqlInternalTransactionType = assembly.GetType($"{assemblyName}.SqlInternalTransaction");

var writeMarsHeaderDataMethod = tdsParserType.GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeMarsHeaderDataNewMethod = typeof(Patches).GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Static | BindingFlags.Public
);

var writeShortMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeShortReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Static | BindingFlags.Public
);

var writeLongMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeLongReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Static | BindingFlags.Public
);

var writeIntMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeIntReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Static | BindingFlags.Public
);

var incrementAndObtainOpenResultCountMethod = tdsParserStateObjectType.GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Instance | BindingFlags.NonPublic
);

var incrementAndObtainOpenResultCountReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Static | BindingFlags.Public
);

var getTransactionIdMethodMethod = sqlInternalTransactionType.GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Instance | BindingFlags.NonPublic
);

var getTransactionIdMethodReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Static | BindingFlags.Public
);

MethodPatcher.Patch(writeMarsHeaderDataMethod, new HarmonyMethod(writeMarsHeaderDataNewMethod));

MethodPatcher.CreateReversePatcher(writeShortMethod, new HarmonyMethod(writeShortReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(writeLongMethod, new HarmonyMethod(writeLongReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(writeIntMethod, new HarmonyMethod(writeIntReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(incrementAndObtainOpenResultCountMethod, new HarmonyMethod(incrementAndObtainOpenResultCountReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(getTransactionIdMethodMethod, new HarmonyMethod(getTransactionIdMethodReverseMethod)).Patch();
string assemblyName = assembly.GetName().Name;

var tdsParserType = assembly.GetType($"{assemblyName}.TdsParser");
var tdsParserStateObjectType = assembly.GetType($"{assemblyName}.TdsParserStateObject");
var sqlInternalTransactionType = assembly.GetType($"{assemblyName}.SqlInternalTransaction");

var writeMarsHeaderDataMethod = tdsParserType.GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeMarsHeaderDataNewMethod = typeof(Patches).GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Static | BindingFlags.Public
);

var writeShortMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeShortReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Static | BindingFlags.Public
);

var writeLongMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeLongReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Static | BindingFlags.Public
);

var writeIntMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeIntReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Static | BindingFlags.Public
);

var incrementAndObtainOpenResultCountMethod = tdsParserStateObjectType.GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Instance | BindingFlags.NonPublic
);

var incrementAndObtainOpenResultCountReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Static | BindingFlags.Public
);

var getTransactionIdMethodMethod = sqlInternalTransactionType.GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Instance | BindingFlags.NonPublic
);

var getTransactionIdMethodReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Static | BindingFlags.Public
);

MethodPatcher.Patch(writeMarsHeaderDataMethod, new HarmonyMethod(writeMarsHeaderDataNewMethod));

Harmony.ReversePatch(writeShortMethod, new HarmonyMethod(writeShortReverseMethod));
Harmony.ReversePatch(writeLongMethod, new HarmonyMethod(writeLongReverseMethod));
Harmony.ReversePatch(writeIntMethod, new HarmonyMethod(writeIntReverseMethod));
Harmony.ReversePatch(incrementAndObtainOpenResultCountMethod, new HarmonyMethod(incrementAndObtainOpenResultCountReverseMethod));
Harmony.ReversePatch(getTransactionIdMethodMethod, new HarmonyMethod(getTransactionIdMethodReverseMethod));
}
}

/// <summary>
Expand All @@ -115,6 +122,7 @@ private static class Patches
/// https://github.com/dotnet/SqlClient/issues/1623
/// </summary>
/// <remarks>https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L10737</remarks>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static bool WriteMarsHeaderData(object __instance, long ____retainedTransactionId, object stateObj, object transaction)
{
ReversePatches.WriteShort(__instance, HEADERTYPE_MARS, stateObj);
Expand Down Expand Up @@ -144,41 +152,46 @@ private static class ReversePatches
/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalTransaction.cs#L152
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static long get_TransactionId(object instance)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(get_TransactionId)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs#L931
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static int IncrementAndObtainOpenResultCount(object instance, object transaction)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(IncrementAndObtainOpenResultCount)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L1643
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void WriteInt(object instance, int v, object stateObj)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(WriteInt)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L1721
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void WriteLong(object instance, long v, object stateObj)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(WriteLong)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L1590
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void WriteShort(object instance, int v, object stateObj)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(WriteShort)}' stub must not be called.");
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/Softwarehelden.Transactions.Oletx/OletxPatcher.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using HarmonyLib;
using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Transactions;

Expand Down Expand Up @@ -89,6 +90,7 @@ private static class Patches
/// <summary>
/// Performs a PSPE enlistment with a non-MSDTC promoter type.
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void EnlistPromotableSinglePhase(
ref Transaction __instance,
ref IPromotableSinglePhaseNotification promotableSinglePhaseNotification,
Expand All @@ -115,6 +117,7 @@ ref Guid promoterType
/// Returns a transaction cookie that is used to propagate/import a distributed
/// transaction on a SQL server that wants to participate in the transaction.
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static bool GetExportCookie(Transaction transaction, byte[] whereabouts, ref byte[] __result)
{
// Should the transaction be promoted using our custom promoter type?
Expand Down Expand Up @@ -164,6 +167,7 @@ public static bool GetExportCookie(Transaction transaction, byte[] whereabouts,
/// <summary>
/// Returns the native DTC transaction for the given transaction instance.
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static bool GetTransactionNative(Transaction transaction, ref IDtcTransaction __result)
{
if (transaction.PromoterType == NonMsdtcPromoterType)
Expand Down