From 230116469357c7bc2f6948241426860fb0dea193 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 16 Feb 2022 13:05:47 -0800 Subject: [PATCH] Forward DefaultDllImportSearchPathsAttribute to the inner P/Invoke. Fixes #65154 --- .../DllImportGenerator/DllImportGenerator.cs | 25 ++++++++--- .../TypeNames.cs | 2 + .../AttributeForwarding.cs | 44 +++++++++++++++++++ 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs index 7d03fd93b12bee..093c27ca88642e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs @@ -197,7 +197,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); } - private static List GenerateSyntaxForForwardedAttributes(AttributeData? suppressGCTransitionAttribute, AttributeData? unmanagedCallConvAttribute) + private static List GenerateSyntaxForForwardedAttributes(AttributeData? suppressGCTransitionAttribute, AttributeData? unmanagedCallConvAttribute, AttributeData? defaultDllImportSearchPathsAttribute) { const string CallConvsField = "CallConvs"; // Manually rehydrate the forwarded attributes with fully qualified types so we don't have to worry about any using directives. @@ -232,6 +232,15 @@ private static List GenerateSyntaxForForwardedAttributes(Attrib } attributes.Add(unmanagedCallConvSyntax); } + if (defaultDllImportSearchPathsAttribute is not null) + { + attributes.Add( + Attribute(ParseName(TypeNames.DefaultDllImportSearchPathsAttribute)).AddArgumentListArguments( + AttributeArgument( + CastExpression(ParseTypeName(TypeNames.DllImportSearchPath), + LiteralExpression(SyntaxKind.NumericLiteralExpression, + Literal((int)defaultDllImportSearchPathsAttribute.ConstructorArguments[0].Value!)))))); + } return attributes; } @@ -392,11 +401,13 @@ private static IncrementalStubGenerationContext CalculateStubInformation(IMethod INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute); INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute); + INamedTypeSymbol? defaultDllImportSearchPathsAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.DefaultDllImportSearchPathsAttribute); // Get any attributes of interest on the method AttributeData? generatedDllImportAttr = null; AttributeData? lcidConversionAttr = null; AttributeData? suppressGCTransitionAttribute = null; AttributeData? unmanagedCallConvAttribute = null; + AttributeData? defaultDllImportSearchPathsAttribute = null; foreach (AttributeData attr in symbol.GetAttributes()) { if (attr.AttributeClass is not null @@ -404,18 +415,22 @@ private static IncrementalStubGenerationContext CalculateStubInformation(IMethod { generatedDllImportAttr = attr; } - else if (lcidConversionAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType)) + else if (lcidConversionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType)) { lcidConversionAttr = attr; } - else if (suppressGCTransitionAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType)) + else if (suppressGCTransitionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType)) { suppressGCTransitionAttribute = attr; } - else if (unmanagedCallConvAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType)) + else if (unmanagedCallConvAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType)) { unmanagedCallConvAttribute = attr; } + else if (defaultDllImportSearchPathsAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, defaultDllImportSearchPathsAttrType)) + { + defaultDllImportSearchPathsAttribute = attr; + } } Debug.Assert(generatedDllImportAttr is not null); @@ -434,7 +449,7 @@ private static IncrementalStubGenerationContext CalculateStubInformation(IMethod // Create the stub. var dllImportStub = DllImportStubContext.Create(symbol, stubDllImportData, environment, generatorDiagnostics, ct); - List additionalAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute); + List additionalAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute, defaultDllImportSearchPathsAttribute); return new IncrementalStubGenerationContext(environment, dllImportStub, additionalAttributes.ToImmutableArray(), stubDllImportData, generatorDiagnostics.Diagnostics.ToImmutableArray()); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index 8111b0787bc88a..c2cda25c889978 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -75,5 +75,7 @@ public static string Unsafe(InteropGenerationOptions options) } public const string System_Runtime_CompilerServices_DisableRuntimeMarshallingAttribute = "System.Runtime.CompilerServices.DisableRuntimeMarshallingAttribute"; + public const string DefaultDllImportSearchPathsAttribute = "System.Runtime.InteropServices.DefaultDllImportSearchPathsAttribute"; + public const string DllImportSearchPath = "System.Runtime.InteropServices.DllImportSearchPath"; } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/AttributeForwarding.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/AttributeForwarding.cs index 978bf1c4ca8f79..764929a378dfb2 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/AttributeForwarding.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/AttributeForwarding.cs @@ -201,6 +201,50 @@ public Native(S s) { } callConvType2)); } + [ConditionalFact] + public async Task DefaultDllImportSearchPathsAttribute() + { + string source = @$" +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +[assembly:DisableRuntimeMarshalling] +partial class C +{{ + [DefaultDllImportSearchPaths(DllImportSearchPath.System32 | DllImportSearchPath.UserDirectories)] + [GeneratedDllImportAttribute(""DoesNotExist"")] + public static partial S Method1(); +}} + +[NativeMarshalling(typeof(Native))] +struct S +{{ +}} + +struct Native +{{ + public Native(S s) {{ }} + public S ToManaged() {{ return default; }} +}} +"; + Compilation origComp = await TestUtils.CreateCompilation(source); + Compilation newComp = TestUtils.RunGenerators(origComp, out _, new Microsoft.Interop.DllImportGenerator()); + Assert.Empty(newComp.GetDiagnostics()); + + ITypeSymbol attributeType = newComp.GetTypeByMetadataName("System.Runtime.InteropServices.DefaultDllImportSearchPathsAttribute")!; + + Assert.NotNull(attributeType); + + IMethodSymbol targetMethod = GetGeneratedPInvokeTargetFromCompilation(newComp); + + DllImportSearchPath expected = DllImportSearchPath.System32 | DllImportSearchPath.UserDirectories; + + Assert.Contains( + targetMethod.GetAttributes(), + attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeType) + && attr.ConstructorArguments.Length == 1 + && expected == (DllImportSearchPath)attr.ConstructorArguments[0].Value!); + } + [ConditionalFact] public async Task OtherAttributeType() {