diff --git a/NeosModLoader/AssemblyHider.cs b/NeosModLoader/AssemblyHider.cs index 94bac76..541779f 100644 --- a/NeosModLoader/AssemblyHider.cs +++ b/NeosModLoader/AssemblyHider.cs @@ -3,6 +3,7 @@ using HarmonyLib; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Reflection; @@ -10,9 +11,41 @@ namespace NeosModLoader { internal static class AssemblyHider { + /// + /// Companies that indicate an assembly is part of .NET. + /// This list was found by debug logging the AssemblyCompanyAttribute for all loaded assemblies. + /// + private static HashSet knownDotNetCompanies = new List() + { + "Mono development team", // used by .NET stuff and Mono.Security + }.Select(company => company.ToLower()).ToHashSet(); + + /// + /// Products that indicate an assembly is part of .NET. + /// This list was found by debug logging the AssemblyProductAttribute for all loaded assemblies. + /// + private static HashSet knownDotNetProducts = new List() + { + "Microsoft® .NET", // used by a few System.* assemblies + "Microsoft® .NET Framework", // used by most of the System.* assemblies + "Mono Common Language Infrastructure", // used by mscorlib stuff + }.Select(product => product.ToLower()).ToHashSet(); + + /// + /// Assemblies that were already loaded when NML started up, minus a couple known non-Neos assemblies. + /// private static HashSet? neosAssemblies; + + /// + /// Assemblies that 100% exist due to a mod + /// private static HashSet? modAssemblies; + /// + /// .NET assembiles we want to ignore in some cases, like the callee check for the AppDomain.GetAssemblies() patch + /// + private static HashSet? dotNetAssemblies; + /// /// Patch Neos's type lookup code to not see mod-related types. This is needed, because users can pass /// arbitrary strings to TypeHelper.FindType(), which can be used to detect if someone is running mods. @@ -23,26 +56,28 @@ internal static void PatchNeos(Harmony harmony, HashSet initialAssembl { if (ModLoaderConfiguration.Get().HideModTypes) { + // initialize the static assembly sets that our patches will need later neosAssemblies = GetNeosAssemblies(initialAssemblies); - modAssemblies = GetModAssemblies(); + modAssemblies = GetModAssemblies(neosAssemblies); + dotNetAssemblies = neosAssemblies.Where(LooksLikeDotNetAssembly).ToHashSet(); // TypeHelper.FindType explicitly does a type search - MethodInfo findTypeTarget = AccessTools.DeclaredMethod(typeof(TypeHelper), nameof(TypeHelper.FindType)); + MethodInfo findTypeTarget = AccessTools.DeclaredMethod(typeof(TypeHelper), nameof(TypeHelper.FindType), new Type[] { typeof(string) }); MethodInfo findTypePatch = AccessTools.DeclaredMethod(typeof(AssemblyHider), nameof(FindTypePostfix)); harmony.Patch(findTypeTarget, postfix: new HarmonyMethod(findTypePatch)); // WorkerManager.IsValidGenericType checks a type for validity, and if it returns `true` it reveals that the type exists - MethodInfo isValidGenericTypeTarget = AccessTools.DeclaredMethod(typeof(WorkerManager), nameof(WorkerManager.IsValidGenericType)); + MethodInfo isValidGenericTypeTarget = AccessTools.DeclaredMethod(typeof(WorkerManager), nameof(WorkerManager.IsValidGenericType), new Type[] { typeof(Type), typeof(bool) }); MethodInfo isValidGenericTypePatch = AccessTools.DeclaredMethod(typeof(AssemblyHider), nameof(IsValidTypePostfix)); harmony.Patch(isValidGenericTypeTarget, postfix: new HarmonyMethod(isValidGenericTypePatch)); // WorkerManager.GetType uses FindType, but upon failure fails back to doing a (strangely) exhausitive reflection-based search for the type - MethodInfo getTypeTarget = AccessTools.DeclaredMethod(typeof(WorkerManager), nameof(WorkerManager.GetType)); + MethodInfo getTypeTarget = AccessTools.DeclaredMethod(typeof(WorkerManager), nameof(WorkerManager.GetType), new Type[] { typeof(string) }); MethodInfo getTypePatch = AccessTools.DeclaredMethod(typeof(AssemblyHider), nameof(FindTypePostfix)); harmony.Patch(getTypeTarget, postfix: new HarmonyMethod(getTypePatch)); // FrooxEngine likes to enumerate all types in all assemblies, which is prone to issues (such as crashing FrooxCode if a type isn't loadable) - MethodInfo getAssembliesTarget = AccessTools.DeclaredMethod(typeof(AppDomain), nameof(AppDomain.GetAssemblies)); + MethodInfo getAssembliesTarget = AccessTools.DeclaredMethod(typeof(AppDomain), nameof(AppDomain.GetAssemblies), new Type[] { }); MethodInfo getAssembliesPatch = AccessTools.DeclaredMethod(typeof(AssemblyHider), nameof(GetAssembliesPostfix)); harmony.Patch(getAssembliesTarget, postfix: new HarmonyMethod(getAssembliesPatch)); } @@ -59,14 +94,15 @@ private static HashSet GetNeosAssemblies(HashSet initialAsse return initialAssemblies; } - private static HashSet GetModAssemblies() + private static HashSet GetModAssemblies(HashSet neosAssemblies) { // start with ALL assemblies HashSet assemblies = AppDomain.CurrentDomain.GetAssemblies().ToHashSet(); - // remove assemblies that already existed before NML loaded + // remove assemblies that we know to have come with Neos assemblies.ExceptWith(neosAssemblies); + // what's left are assemblies that magically appeared during the mod loading process. So mods and their dependencies. return assemblies; } @@ -93,7 +129,7 @@ private static bool IsModAssembly(Assembly assembly, string typeOrAssembly, stri // known type from a mod assembly if (log) { - Logger.DebugInternal($"Hid {typeOrAssembly} \"{name}\" from Neos"); + Logger.DebugFuncInternal(() => $"Hid {typeOrAssembly} \"{name}\" from Neos"); } return true; // hide the thing } @@ -165,15 +201,55 @@ private static void IsValidTypePostfix(ref bool __result, Type type) private static void GetAssembliesPostfix(ref Assembly[] __result) { - Assembly? callingAssembly = Util.GetCallingAssembly(); + Assembly? callingAssembly = GetCallingAssembly(new(1)); if (callingAssembly != null && neosAssemblies!.Contains(callingAssembly)) { - // if we're being called by Neos, then hide mod assemblies + // if we're being called by Neos code, then hide mod assemblies Logger.DebugFuncInternal(() => $"Intercepting call to AppDomain.GetAssemblies() from {callingAssembly}"); __result = __result .Where(assembly => !IsModAssembly(assembly, forceShowLate: true)) // it turns out Neos itself late-loads a bunch of stuff, so we force-show late-loaded assemblies here .ToArray(); } } + + /// + /// Get the calling assembly by stack trace analysis, ignoring .NET assemblies. + /// This implementation is SPECIFICALLY for the AppDomain.GetAssemblies() patch and may not be valid for other use-cases. + /// + /// A stack trace captured by the callee + /// The executing assembly, or null if none found + private static Assembly? GetCallingAssembly(StackTrace stackTrace) + { + for (int i = 0; i < stackTrace.FrameCount; i++) + { + Assembly? assembly = stackTrace.GetFrame(i)?.GetMethod()?.DeclaringType?.Assembly; + // .NET calls AppDomain.GetAssemblies() a bunch internally, and we don't want to intercept those calls UNLESS they originated from Neos code. + if (assembly != null && !dotNetAssemblies!.Contains(assembly)) + { + return assembly; + } + } + return null; + } + + private static bool LooksLikeDotNetAssembly(Assembly assembly) + { + // check the assembly's company + string? company = assembly.GetCustomAttribute()?.Company; + if (company != null && knownDotNetCompanies.Contains(company.ToLower())) + { + return true; + } + + // check the assembly's product + string? product = assembly.GetCustomAttribute()?.Product; + if (product != null && knownDotNetProducts.Contains(product.ToLower())) + { + return true; + } + + // nothing matched, this is probably not part of .NET + return false; + } } } diff --git a/NeosModLoader/Logger.cs b/NeosModLoader/Logger.cs index ba54fb3..1ea1176 100644 --- a/NeosModLoader/Logger.cs +++ b/NeosModLoader/Logger.cs @@ -1,5 +1,6 @@ using BaseX; using System; +using System.Diagnostics; namespace NeosModLoader { @@ -25,7 +26,7 @@ internal static void DebugFuncExternal(Func messageProducer) { if (IsDebugEnabled()) { - LogInternal(LogType.DEBUG, messageProducer(), SourceFromStackTrace()); + LogInternal(LogType.DEBUG, messageProducer(), SourceFromStackTrace(new(1))); } } @@ -41,7 +42,7 @@ internal static void DebugExternal(object message) { if (IsDebugEnabled()) { - LogInternal(LogType.DEBUG, message, SourceFromStackTrace()); + LogInternal(LogType.DEBUG, message, SourceFromStackTrace(new(1))); } } @@ -49,19 +50,19 @@ internal static void DebugListExternal(object[] messages) { if (IsDebugEnabled()) { - LogListInternal(LogType.DEBUG, messages, SourceFromStackTrace()); + LogListInternal(LogType.DEBUG, messages, SourceFromStackTrace(new(1))); } } internal static void MsgInternal(string message) => LogInternal(LogType.INFO, message); - internal static void MsgExternal(object message) => LogInternal(LogType.INFO, message, SourceFromStackTrace()); - internal static void MsgListExternal(object[] messages) => LogListInternal(LogType.INFO, messages, SourceFromStackTrace()); + internal static void MsgExternal(object message) => LogInternal(LogType.INFO, message, SourceFromStackTrace(new(1))); + internal static void MsgListExternal(object[] messages) => LogListInternal(LogType.INFO, messages, SourceFromStackTrace(new(1))); internal static void WarnInternal(string message) => LogInternal(LogType.WARN, message); - internal static void WarnExternal(object message) => LogInternal(LogType.WARN, message, SourceFromStackTrace()); - internal static void WarnListExternal(object[] messages) => LogListInternal(LogType.WARN, messages, SourceFromStackTrace()); + internal static void WarnExternal(object message) => LogInternal(LogType.WARN, message, SourceFromStackTrace(new(1))); + internal static void WarnListExternal(object[] messages) => LogListInternal(LogType.WARN, messages, SourceFromStackTrace(new(1))); internal static void ErrorInternal(string message) => LogInternal(LogType.ERROR, message); - internal static void ErrorExternal(object message) => LogInternal(LogType.ERROR, message, SourceFromStackTrace()); - internal static void ErrorListExternal(object[] messages) => LogListInternal(LogType.ERROR, messages, SourceFromStackTrace()); + internal static void ErrorExternal(object message) => LogInternal(LogType.ERROR, message, SourceFromStackTrace(new(1))); + internal static void ErrorListExternal(object[] messages) => LogListInternal(LogType.ERROR, messages, SourceFromStackTrace(new(1))); private static void LogInternal(string logTypePrefix, object message, string? source = null) { @@ -94,10 +95,10 @@ private static void LogListInternal(string logTypePrefix, object[] messages, str } } - private static string? SourceFromStackTrace() + private static string? SourceFromStackTrace(StackTrace stackTrace) { // MsgExternal() and Msg() are above us in the stack - return Util.ExecutingMod(2)?.Name; + return Util.ExecutingMod(stackTrace)?.Name; } private sealed class LogType diff --git a/NeosModLoader/ModConfiguration.cs b/NeosModLoader/ModConfiguration.cs index 737c041..1b1a97a 100644 --- a/NeosModLoader/ModConfiguration.cs +++ b/NeosModLoader/ModConfiguration.cs @@ -527,7 +527,7 @@ internal void Save(bool saveDefaultValues = false, bool immediate = false) { Thread thread = Thread.CurrentThread; - NeosMod? callee = Util.ExecutingMod(); + NeosMod? callee = Util.ExecutingMod(new(1)); Action? saveAction = null; // get saved state for this callee diff --git a/NeosModLoader/NeosVersionReset.cs b/NeosModLoader/NeosVersionReset.cs index b471aa6..3af7550 100644 --- a/NeosModLoader/NeosVersionReset.cs +++ b/NeosModLoader/NeosVersionReset.cs @@ -40,11 +40,13 @@ internal static void Initialize() .Where(IsPostXProcessed) .ToArray(); - string potentialPlugins = postxedAssemblies - .Select(a => Path.GetFileName(a.Location)) - .Join(delimiter: ", "); - - Logger.DebugFuncInternal(() => $"Found {postxedAssemblies.Length} potential plugins: {potentialPlugins}"); + Logger.DebugFuncInternal(() => + { + string potentialPlugins = postxedAssemblies + .Select(a => Path.GetFileName(a.Location)) + .Join(delimiter: ", "); + return $"Found {postxedAssemblies.Length} potential plugins: {potentialPlugins}"; + }); HashSet expectedPostXAssemblies = GetExpectedPostXAssemblies(); @@ -68,8 +70,12 @@ internal static void Initialize() }) .ToArray(); - string actualPlugins = plugins.Keys.Join(delimiter: ", "); - Logger.DebugFuncInternal(() => $"Found {plugins.Count} actual plugins: {actualPlugins}"); + + Logger.DebugFuncInternal(() => + { + string actualPlugins = plugins.Keys.Join(delimiter: ", "); + return $"Found {plugins.Count} actual plugins: {actualPlugins}"; + }); // warn about the assemblies we couldn't map to plugins foreach (Assembly assembly in unmatchedAssemblies) diff --git a/NeosModLoader/Util.cs b/NeosModLoader/Util.cs index f615e2e..7d6c874 100644 --- a/NeosModLoader/Util.cs +++ b/NeosModLoader/Util.cs @@ -13,17 +13,13 @@ namespace NeosModLoader internal static class Util { /// - /// Get the executing mod by stack trace analysis. Always skips the first two frames, being this method and you, the caller. + /// Get the executing mod by stack trace analysis. /// You may skip extra frames if you know your callers are guaranteed to be NML code. /// - /// The number NML method calls above you in the stack + /// A stack trace captured by the callee /// The executing mod, or null if none found - internal static NeosMod? ExecutingMod(int nmlCalleeDepth = 0) + internal static NeosMod? ExecutingMod(StackTrace stackTrace) { - // example: ExecutingMod(), SourceFromStackTrace(), MsgExternal(), Msg(), ACTUAL MOD CODE - // you'd skip 4 frames - // we always skip ExecutingMod() and whoever called us (as this is an internal method), which is where the 2 comes from - StackTrace stackTrace = new(2 + nmlCalleeDepth); for (int i = 0; i < stackTrace.FrameCount; i++) { Assembly? assembly = stackTrace.GetFrame(i)?.GetMethod()?.DeclaringType?.Assembly; @@ -35,26 +31,6 @@ internal static class Util return null; } - /// - /// Get the calling assembly by stack trace analysis. Always skips the first one frame, being this method and you, the caller. - /// - /// The number of extra frame skip in the stack - /// The executing mod, or null if none found - internal static Assembly? GetCallingAssembly(int skipFrames = 0) - { - // same logic as ExecutingMod(), but simpler case - StackTrace stackTrace = new(2 + skipFrames); - for (int i = 0; i < stackTrace.FrameCount; i++) - { - Assembly? assembly = stackTrace.GetFrame(i)?.GetMethod()?.DeclaringType?.Assembly; - if (assembly != null) - { - return assembly; - } - } - return null; - } - /// /// Used to debounce a method call. The underlying method will be called after there have been no additional calls /// for n milliseconds. @@ -142,6 +118,7 @@ private static bool CheckType(Type type, Predicate predicate) { return false; } + try { string _name = type.Name; @@ -158,7 +135,7 @@ private static bool CheckType(Type type, Predicate predicate) } catch (Exception e) { - Logger.DebugFuncInternal(() => $"Could not load type \"{type.Name}\": {e}"); + Logger.DebugFuncInternal(() => $"Could not load type \"{type}\": {e}"); return false; } }