Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 'Importer.Import(FieldInfo)' and 'Importer.ImportDeclaringType(Type)' #452

Merged
merged 1 commit into from
Jan 31, 2022

Conversation

wwh1004
Copy link
Contributor

@wwh1004 wwh1004 commented Jan 23, 2022

Linked issue: #445

Test code:

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using dnlib.DotNet;
using dnlib.DotNet.MD;

static class Program {
	static readonly SigComparer FieldComparer = new(SigComparerOptions.DontCompareTypeScope | SigComparerOptions.CompareMethodFieldDeclaringType);
	static readonly SigComparer MethodComparer = new(SigComparerOptions.DontCompareTypeScope | SigComparerOptions.CompareMethodFieldDeclaringType);

	static void Main() {
		Test(typeof(Program).Module);
		Test(typeof(ModuleDef).Module);
		Console.WriteLine("Pass");
		Console.ReadKey(true);
	}

	static void Test(Module module) {
		using var moduleDef = ModuleDefMD.Load(module, ModuleDef.CreateModuleContext());
		var importer1 = new Importer(moduleDef, ImporterOptions.TryToUseExistingAssemblyRefs);
		var importer2 = new Importer(moduleDef, ImporterOptions.TryToUseDefs | ImporterOptions.TryToUseExistingAssemblyRefs);
		foreach (var field in moduleDef.EnumerateMemberRefs().Where(t => t.IsFieldRef).Cast<IField>().Concat(moduleDef.EnumerateFields())) {
			if (field is IContainsGenericParameter cgp && cgp.ContainsGenericParameter)
				continue;

			var fieldInfo = module.ResolveField(field.MDToken.ToInt32());
			var importedField1 = importer1.Import(fieldInfo);
			var importedField2 = importer2.Import(fieldInfo);

			Debug.Assert(field is not IMemberDef || importedField2 is IMemberDef);

			int h1 = FieldComparer.GetHashCode(field);
			int h2 = FieldComparer.GetHashCode(fieldInfo);
			int h3 = FieldComparer.GetHashCode(importedField1);
			int h4 = FieldComparer.GetHashCode(importedField2);
			Debug.Assert(h1 == h2 && h2 == h3 && h3 == h4);
			Debug.Assert(FieldComparer.Equals(field, fieldInfo));
			Debug.Assert(FieldComparer.Equals(field, importedField1));
			Debug.Assert(FieldComparer.Equals(field, importedField2));
			Debug.Assert(FieldComparer.Equals(fieldInfo, importedField1));
			Debug.Assert(FieldComparer.Equals(fieldInfo, importedField2));
			Debug.Assert(FieldComparer.Equals(importedField1, importedField2));
		}
		foreach (var method in moduleDef.EnumerateMemberRefs().Where(t => t.IsMethodRef).Cast<IMethod>().Concat(moduleDef.EnumerateMethods())) {
			if (method is IContainsGenericParameter cgp && cgp.ContainsGenericParameter)
				continue;
			if (method.DeclaringType is TypeDef td && td.IsImport && method.Name.StartsWith("_VtblGap", StringComparison.Ordinal))
				continue;

			var methodInfo = module.ResolveMethod(method.MDToken.ToInt32());
			var importedMethod1 = importer1.Import(methodInfo);
			var importedMethod2 = importer2.Import(methodInfo);

			Debug.Assert(method is not IMemberDef || importedMethod2 is IMemberDef);

			int h1 = MethodComparer.GetHashCode(method);
			int h2 = MethodComparer.GetHashCode(methodInfo);
			int h3 = MethodComparer.GetHashCode(importedMethod1);
			int h4 = MethodComparer.GetHashCode(importedMethod2);
			Debug.Assert(h1 == h2 && h2 == h3 && h3 == h4);
			Debug.Assert(MethodComparer.Equals(method, methodInfo));
			Debug.Assert(MethodComparer.Equals(method, importedMethod1));
			Debug.Assert(MethodComparer.Equals(method, importedMethod2));
			Debug.Assert(MethodComparer.Equals(methodInfo, importedMethod1));
			Debug.Assert(MethodComparer.Equals(methodInfo, importedMethod2));
			Debug.Assert(MethodComparer.Equals(importedMethod1, importedMethod2));
		}
	}

	static class G<T1, T2> where T1 : class {
		static volatile T1 Value1;
		static volatile IDictionary<T1, T2> Value2;

		static void MethodToImport<T3>() where T3 : class {
			Value1 = null;
			Value2 = null;
			G<string, string>.Value1 = null;
			G<string, string>.Value2 = null;
			G<string, T3>.Value1 = null;
			G<string, T3>.Value2 = null;
			G<T3, T3>.Value1 = null;
			G<T3, T3>.Value2 = null;
			G2.Instance = null;
			G2<T3>.Instance = null;
		}

		class G2 {
			public static G2 Instance = new();

			public static G2 Method(G2 g2) { throw null; }
		}

		class G3 {
			public static G2 Instance = new(); // NOT G3!!!, to test MustTreatTypeAsGenericInstType

			public static G2 Method(G2 g2) { throw null; } // NOT G3!!!, to test MustTreatTypeAsGenericInstType
		}

		class G2<T4> {
			public static G2<T4> Instance = new();
		}
	}
}

static class ModuleDefExtensions {
	public static IEnumerable<TypeDef> EnumerateTypes(this ModuleDef module) {
		if (module is ModuleDefMD moduleDefMD) {
			uint typeTableLength = moduleDefMD.TablesStream.TypeDefTable.Rows;
			for (uint rid = 1; rid <= typeTableLength; rid++)
				yield return moduleDefMD.ResolveTypeDef(rid);
		}
		else {
			for (uint rid = 1; ; rid++) {
				if (module.ResolveToken(new MDToken(Table.TypeDef, rid)) is not TypeDef type)
					yield break;
				yield return type;
			}
		}
	}

	public static IEnumerable<FieldDef> EnumerateFields(this ModuleDef module) {
		if (module is ModuleDefMD moduleDefMD) {
			uint fieldTableLength = moduleDefMD.TablesStream.FieldTable.Rows;
			for (uint rid = 1; rid <= fieldTableLength; rid++)
				yield return moduleDefMD.ResolveField(rid);
		}
		else {
			for (uint rid = 1; ; rid++) {
				if (module.ResolveToken(new MDToken(Table.Field, rid)) is not FieldDef field)
					yield break;
				yield return field;
			}
		}
	}

	public static IEnumerable<MethodDef> EnumerateMethods(this ModuleDef module) {
		if (module is ModuleDefMD moduleDefMD) {
			uint methodTableLength = moduleDefMD.TablesStream.MethodTable.Rows;
			for (uint rid = 1; rid <= methodTableLength; rid++)
				yield return moduleDefMD.ResolveMethod(rid);
		}
		else {
			for (uint rid = 1; ; rid++) {
				if (module.ResolveToken(new MDToken(Table.Method, rid)) is not MethodDef method)
					yield break;
				yield return method;
			}
		}
	}

	public static IEnumerable<MemberRef> EnumerateMemberRefs(this ModuleDef module) {
		if (module is ModuleDefMD moduleDefMD) {
			uint memberRefTableLength = moduleDefMD.TablesStream.MemberRefTable.Rows;
			for (uint rid = 1; rid <= memberRefTableLength; rid++)
				yield return moduleDefMD.ResolveMemberRef(rid);
		}
		else {
			for (uint rid = 1; ; rid++) {
				if (module.ResolveToken(new MDToken(Table.MemberRef, rid)) is not MemberRef memberRef)
					yield break;
				yield return memberRef;
			}
		}
	}
}

@wtfsck wtfsck merged commit 9ea8189 into 0xd4d:master Jan 31, 2022
@wtfsck
Copy link
Contributor

wtfsck commented Jan 31, 2022

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants