Skip to content

Commit

Permalink
[OffensiveGC] code clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
RemoteNet committed Dec 27, 2024
1 parent c4d9601 commit 266df21
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
31 changes: 15 additions & 16 deletions src/MsvcOffensiveGcHelper/dllmain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
}





#define EXPORT __declspec(dllexport)

#define INITIAL_SIZE 10
Expand All @@ -36,13 +33,6 @@ void** trackedAddresses = NULL;
size_t trackedCount = 0;
size_t arraySize = INITIAL_SIZE;

// Pointer to the original free function
void (*OriginalFree)(void* ptr) = NULL;
// Pointer to the original "_free" function
void (*Original_free)(void* ptr) = NULL;
// Pointer to the original "_free_dbg" function
void (*Original_free_dbg)(void* ptr) = NULL;

// ----------------
// Address Tracking
// ----------------
Expand All @@ -51,17 +41,27 @@ void (*Original_free_dbg)(void* ptr) = NULL;
void GrowArrayIfNeeded() {
if (trackedCount >= arraySize) {
arraySize *= 2;
trackedAddresses = (void**)realloc(trackedAddresses, arraySize * sizeof(void*));
void** temp = (void**)realloc(trackedAddresses, arraySize * sizeof(void*));
if (temp == NULL) {
// Handle memory allocation failure
msgboxf("Memory allocation failed while growing the array.");
return;
}
trackedAddresses = temp;
}
}

// Exported method to add an address to the tracking list
EXPORT void AddAddress(void* address) {
if (!trackedAddresses) {
trackedAddresses = (void**)malloc(arraySize * sizeof(void*));
if (!trackedAddresses) {
// Handle memory allocation failure
msgboxf("Memory allocation failed while initializing the array.");
return;
}
}
GrowArrayIfNeeded();
trackedAddresses[trackedCount++] = address;
trackedAddresses[trackedCount++] = address;
}

// Exported method to remove an address from the tracking list
Expand All @@ -85,7 +85,7 @@ void* originalFreeFunctions[MAX_HOOKS] = { nullptr };

// Macro to define the hook function for each index
// garbgeOrDebugArg is an ugly hack to support both free/_free (1 argument) and _free_dbg (2 arguments)
// I'm risking violating the stack and this only works because both my function and the calles use `cdecl`
// I'm risking violating the stack and this only works because both my function and the callees use `cdecl`
#define DEFINE_HOOK_FOR_FREE(index) \
EXPORT void HookForFree##index(void* ptr, void* garbgeOrDebugArg) { \
debugf("[mogHelper] HookForFree"#index" called for ptr = %p\n", ptr); \
Expand All @@ -99,7 +99,7 @@ void* originalFreeFunctions[MAX_HOOKS] = { nullptr };
if (originalFreeFunctions[index] != nullptr) { \
reinterpret_cast<void(*)(void*, void*)>(originalFreeFunctions[index])(ptr, garbgeOrDebugArg); \
} else { \
debugf("[mogHelper][ERROR] HookForFree"#index" could not free ptr = %p because OriginalFree func ptr was null...\n", ptr); \
debugf("[mogHelper][ERROR] HookForFree"#index" could not free ptr = %p because originalFreeFunctions["#index"] func ptr was null...\n", ptr); \
} \
}

Expand Down Expand Up @@ -166,7 +166,6 @@ BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserv
trackedAddresses = NULL;
trackedCount = 0;
arraySize = INITIAL_SIZE;
OriginalFree = NULL;
break;
case DLL_PROCESS_DETACH:
if (trackedAddresses) {
Expand Down
26 changes: 15 additions & 11 deletions src/ScubaDiver/MsvcOffensiveGC.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,32 +132,36 @@ public void HookModules(List<UndecoratedModule> modules)
_alreadyHookedModules.AddRange(modules);
}

public void HookAllFreeFuncs(UndecoratedModule target, List<UndecoratedModule> allModules)
public void HookAllFreeFuncs(UndecoratedModule targetUndecoratedModule, List<UndecoratedModule> allModules)
{
// Make sure our C++ Helper is loaded before accessing anything from the `MsvcOffensiveGcHelper` class
// otherwise the loading the P/Invoke methods will fail on "Failed to load DLL.
// otherwise the loading the P/Invoke methods will fail on "Failed to load DLL".
System.Reflection.Assembly assm = typeof(MsvcOffensiveGC).Assembly;
string assmDir = System.IO.Path.GetDirectoryName(assm.Location);
string helperPath = System.IO.Path.Combine(assmDir, "MsvcOffensiveGcHelper.dll");
var res = PInvoke.LoadLibrary(helperPath);
FreeLibrarySafeHandle res = PInvoke.LoadLibrary(helperPath);
if (res.IsInvalid)
{
throw new Exception($"LoadLibrary failed for {helperPath}");
}

int attemptedFreeFuncs = 0;
foreach (string funcName in new[] { "free", "_free", "_free_dbg" })
{
// Logger.Debug($"[{nameof(MsvcOffensiveGC)}] Starting to hook '{funcName}'s...");
Dictionary<ModuleInfo, DllExport> funcs = FreeFinder.Find(allModules, funcName);
if (funcs.Count == 0)
Dictionary<ModuleInfo, DllExport> freeExportedFunctions = FreeFinder.Find(allModules, funcName);
if (freeExportedFunctions.Count == 0)
{
// Logger.Debug($"[{nameof(MsvcOffensiveGC)}] WARNING! '{funcName}' was not found.");
Logger.Debug($"[{nameof(MsvcOffensiveGC)}] WARNING! '{funcName}' was not found.");
continue;
}
if (funcs.Count > 1)
if (freeExportedFunctions.Count > 1)
{
// Logger.Debug($"[{nameof(MsvcOffensiveGC)}] WARNING! Found '{funcName}' in more then 1 module: " +string.Join(", ", funcs.Keys.Select(a => a.Name).ToArray()));
Logger.Debug($"[{nameof(MsvcOffensiveGC)}] WARNING! Found '{funcName}' in more then 1 module: " + string.Join(", ", freeExportedFunctions.Keys.Select(a => a.Name).ToArray()));
}
foreach (var kvp in funcs)
foreach (var moduleToFreeExport in freeExportedFunctions)
{
DllExport freeFunc = kvp.Value;
DllExport freeFunc = moduleToFreeExport.Value;
// Logger.Debug($"[{nameof(MsvcOffensiveGC)}] Chose '{funcName}' from {kvp.Key.Name}");

// Find out native replacement function for the given func name
Expand All @@ -166,7 +170,7 @@ public void HookAllFreeFuncs(UndecoratedModule target, List<UndecoratedModule> a
// Logger.Debug($"[{nameof(MsvcOffensiveGC)}] Hooking '{funcName}' at 0x{freeFunc.Address:X16} (from {kvp.Key.Name}), " +
//$"in the IAT of {target.Name}. " +
//$"Replacement Address: 0x{replacementPtr:X16}");
ModuleInfo targetModule = target.ModuleInfo;
ModuleInfo targetModule = targetUndecoratedModule.ModuleInfo;
bool replacementRes = Loader.HookIAT((IntPtr)(ulong)targetModule.BaseAddress, (IntPtr)freeFunc.Address, replacementPtr);


Expand Down

0 comments on commit 266df21

Please sign in to comment.