Skip to content

Commit

Permalink
force aggressive JIT optimization and disable inlining for patchers (#21
Browse files Browse the repository at this point in the history
)
  • Loading branch information
swh-cb authored May 27, 2022
1 parent 924adf1 commit 8e968e2
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 81 deletions.
175 changes: 94 additions & 81 deletions src/Softwarehelden.Transactions.Oletx/MsSqlPatcher.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using HarmonyLib;
using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace Softwarehelden.Transactions.Oletx
{
Expand All @@ -20,84 +22,89 @@ public static class MsSqlPatcher
/// <param name="assembly">Microsoft.Data.SqlClient or System.Data.SqlClient assembly</param>
public static void Patch(Assembly assembly)
{
if (assembly == null)
// Enable MSSQL data provider patch for Windows only since distributed transactions only
// work on Windows
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
throw new ArgumentNullException(nameof(assembly), "The Microsoft.Data.SqlClient or System.Data.SqlClient assembly must be set.");
}
if (assembly == null)
{
throw new ArgumentNullException(nameof(assembly), "The Microsoft.Data.SqlClient or System.Data.SqlClient assembly must be set.");
}

string assemblyName = assembly.GetName().Name;

var tdsParserType = assembly.GetType($"{assemblyName}.TdsParser");
var tdsParserStateObjectType = assembly.GetType($"{assemblyName}.TdsParserStateObject");
var sqlInternalTransactionType = assembly.GetType($"{assemblyName}.SqlInternalTransaction");

var writeMarsHeaderDataMethod = tdsParserType.GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeMarsHeaderDataNewMethod = typeof(Patches).GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Static | BindingFlags.Public
);

var writeShortMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeShortReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Static | BindingFlags.Public
);

var writeLongMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeLongReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Static | BindingFlags.Public
);

var writeIntMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeIntReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Static | BindingFlags.Public
);

var incrementAndObtainOpenResultCountMethod = tdsParserStateObjectType.GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Instance | BindingFlags.NonPublic
);

var incrementAndObtainOpenResultCountReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Static | BindingFlags.Public
);

var getTransactionIdMethodMethod = sqlInternalTransactionType.GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Instance | BindingFlags.NonPublic
);

var getTransactionIdMethodReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Static | BindingFlags.Public
);

MethodPatcher.Patch(writeMarsHeaderDataMethod, new HarmonyMethod(writeMarsHeaderDataNewMethod));

MethodPatcher.CreateReversePatcher(writeShortMethod, new HarmonyMethod(writeShortReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(writeLongMethod, new HarmonyMethod(writeLongReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(writeIntMethod, new HarmonyMethod(writeIntReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(incrementAndObtainOpenResultCountMethod, new HarmonyMethod(incrementAndObtainOpenResultCountReverseMethod)).Patch();
MethodPatcher.CreateReversePatcher(getTransactionIdMethodMethod, new HarmonyMethod(getTransactionIdMethodReverseMethod)).Patch();
string assemblyName = assembly.GetName().Name;

var tdsParserType = assembly.GetType($"{assemblyName}.TdsParser");
var tdsParserStateObjectType = assembly.GetType($"{assemblyName}.TdsParserStateObject");
var sqlInternalTransactionType = assembly.GetType($"{assemblyName}.SqlInternalTransaction");

var writeMarsHeaderDataMethod = tdsParserType.GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeMarsHeaderDataNewMethod = typeof(Patches).GetMethod(
nameof(Patches.WriteMarsHeaderData),
BindingFlags.Static | BindingFlags.Public
);

var writeShortMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeShortReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteShort),
BindingFlags.Static | BindingFlags.Public
);

var writeLongMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeLongReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteLong),
BindingFlags.Static | BindingFlags.Public
);

var writeIntMethod = tdsParserType.GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Instance | BindingFlags.NonPublic
);

var writeIntReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.WriteInt),
BindingFlags.Static | BindingFlags.Public
);

var incrementAndObtainOpenResultCountMethod = tdsParserStateObjectType.GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Instance | BindingFlags.NonPublic
);

var incrementAndObtainOpenResultCountReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.IncrementAndObtainOpenResultCount),
BindingFlags.Static | BindingFlags.Public
);

var getTransactionIdMethodMethod = sqlInternalTransactionType.GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Instance | BindingFlags.NonPublic
);

var getTransactionIdMethodReverseMethod = typeof(ReversePatches).GetMethod(
nameof(ReversePatches.get_TransactionId),
BindingFlags.Static | BindingFlags.Public
);

MethodPatcher.Patch(writeMarsHeaderDataMethod, new HarmonyMethod(writeMarsHeaderDataNewMethod));

Harmony.ReversePatch(writeShortMethod, new HarmonyMethod(writeShortReverseMethod));
Harmony.ReversePatch(writeLongMethod, new HarmonyMethod(writeLongReverseMethod));
Harmony.ReversePatch(writeIntMethod, new HarmonyMethod(writeIntReverseMethod));
Harmony.ReversePatch(incrementAndObtainOpenResultCountMethod, new HarmonyMethod(incrementAndObtainOpenResultCountReverseMethod));
Harmony.ReversePatch(getTransactionIdMethodMethod, new HarmonyMethod(getTransactionIdMethodReverseMethod));
}
}

/// <summary>
Expand All @@ -115,6 +122,7 @@ private static class Patches
/// https://github.com/dotnet/SqlClient/issues/1623
/// </summary>
/// <remarks>https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L10737</remarks>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static bool WriteMarsHeaderData(object __instance, long ____retainedTransactionId, object stateObj, object transaction)
{
ReversePatches.WriteShort(__instance, HEADERTYPE_MARS, stateObj);
Expand Down Expand Up @@ -144,41 +152,46 @@ private static class ReversePatches
/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalTransaction.cs#L152
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static long get_TransactionId(object instance)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(get_TransactionId)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs#L931
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static int IncrementAndObtainOpenResultCount(object instance, object transaction)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(IncrementAndObtainOpenResultCount)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L1643
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void WriteInt(object instance, int v, object stateObj)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(WriteInt)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L1721
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void WriteLong(object instance, long v, object stateObj)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(WriteLong)}' stub must not be called.");
}

/// <summary>
/// https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs#L1590
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void WriteShort(object instance, int v, object stateObj)
{
throw new InvalidOperationException();
throw new InvalidOperationException($"The '{nameof(WriteShort)}' stub must not be called.");
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/Softwarehelden.Transactions.Oletx/OletxPatcher.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using HarmonyLib;
using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Transactions;

Expand Down Expand Up @@ -89,6 +90,7 @@ private static class Patches
/// <summary>
/// Performs a PSPE enlistment with a non-MSDTC promoter type.
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static void EnlistPromotableSinglePhase(
ref Transaction __instance,
ref IPromotableSinglePhaseNotification promotableSinglePhaseNotification,
Expand All @@ -115,6 +117,7 @@ ref Guid promoterType
/// Returns a transaction cookie that is used to propagate/import a distributed
/// transaction on a SQL server that wants to participate in the transaction.
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static bool GetExportCookie(Transaction transaction, byte[] whereabouts, ref byte[] __result)
{
// Should the transaction be promoted using our custom promoter type?
Expand Down Expand Up @@ -164,6 +167,7 @@ public static bool GetExportCookie(Transaction transaction, byte[] whereabouts,
/// <summary>
/// Returns the native DTC transaction for the given transaction instance.
/// </summary>
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.AggressiveOptimization)]
public static bool GetTransactionNative(Transaction transaction, ref IDtcTransaction __result)
{
if (transaction.PromoterType == NonMsdtcPromoterType)
Expand Down

0 comments on commit 8e968e2

Please sign in to comment.