diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs new file mode 100644 index 00000000000000..e1ded7b4ab1842 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Comparers.cs @@ -0,0 +1,95 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; + +namespace Microsoft.Interop +{ + internal static class Comparers + { + /// + /// Comparer for the set of all of the generated stubs and diagnostics generated for each of them. + /// + public static readonly IEqualityComparer)>> GeneratedSourceSet = new ImmutableArraySequenceEqualComparer<(string, ImmutableArray)>(new CustomValueTupleElementComparer>(EqualityComparer.Default, new ImmutableArraySequenceEqualComparer(EqualityComparer.Default))); + + /// + /// Comparer for an individual generated stub source as a string and the generated diagnostics for the stub. + /// + public static readonly IEqualityComparer<(string, ImmutableArray)> GeneratedSource = new CustomValueTupleElementComparer>(EqualityComparer.Default, new ImmutableArraySequenceEqualComparer(EqualityComparer.Default)); + + /// + /// Comparer for an individual generated stub source as a syntax tree and the generated diagnostics for the stub. + /// + public static readonly IEqualityComparer<(MemberDeclarationSyntax Syntax, ImmutableArray Diagnostics)> GeneratedSyntax = new CustomValueTupleElementComparer>(new SyntaxEquivalentComparer(), new ImmutableArraySequenceEqualComparer(EqualityComparer.Default)); + + /// + /// Comparer for the context used to generate a stub and the original user-provided syntax that triggered stub creation. + /// + public static readonly IEqualityComparer<(MethodDeclarationSyntax Syntax, DllImportGenerator.IncrementalStubGenerationContext StubContext)> CalculatedContextWithSyntax = new CustomValueTupleElementComparer(new SyntaxEquivalentComparer(), EqualityComparer.Default); + } + + /// + /// Generic comparer to compare two instances element by element. + /// + /// The type of immutable array element. + internal class ImmutableArraySequenceEqualComparer : IEqualityComparer> + { + private readonly IEqualityComparer elementComparer; + + /// + /// Creates an with a custom comparer for the elements of the collection. + /// + /// The comparer instance for the collection elements. + public ImmutableArraySequenceEqualComparer(IEqualityComparer elementComparer) + { + this.elementComparer = elementComparer; + } + + public bool Equals(ImmutableArray x, ImmutableArray y) + { + return x.SequenceEqual(y, elementComparer); + } + + public int GetHashCode(ImmutableArray obj) + { + throw new UnreachableException(); + } + } + + internal class SyntaxEquivalentComparer : IEqualityComparer + { + public bool Equals(SyntaxNode x, SyntaxNode y) + { + return x.IsEquivalentTo(y); + } + + public int GetHashCode(SyntaxNode obj) + { + throw new UnreachableException(); + } + } + + internal class CustomValueTupleElementComparer : IEqualityComparer<(T, U)> + { + private readonly IEqualityComparer item1Comparer; + private readonly IEqualityComparer item2Comparer; + + public CustomValueTupleElementComparer(IEqualityComparer item1Comparer, IEqualityComparer item2Comparer) + { + this.item1Comparer = item1Comparer; + this.item2Comparer = item2Comparer; + } + + public bool Equals((T, U) x, (T, U) y) + { + return item1Comparer.Equals(x.Item1, y.Item1) && item2Comparer.Equals(x.Item2, y.Item2); + } + + public int GetHashCode((T, U) obj) + { + throw new UnreachableException(); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs index 802858ddb1f51a..92668a646d46ba 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportGenerator.cs @@ -1,164 +1,186 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading; - -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop { [Generator] - public class DllImportGenerator : ISourceGenerator + public class DllImportGenerator : IIncrementalGenerator { private const string GeneratedDllImport = nameof(GeneratedDllImport); private const string GeneratedDllImportAttribute = nameof(GeneratedDllImportAttribute); private static readonly Version MinimumSupportedFrameworkVersion = new Version(5, 0); - public void Execute(GeneratorExecutionContext context) + internal sealed record IncrementalStubGenerationContext(DllImportStubContext StubContext, ImmutableArray ForwardedAttributes, GeneratedDllImportData DllImportData, ImmutableArray Diagnostics) { - if (context.SyntaxContextReceiver is not SyntaxContextReceiver synRec - || !synRec.Methods.Any()) + public bool Equals(IncrementalStubGenerationContext? other) { - return; + return other is not null + && StubContext.Equals(other.StubContext) + && DllImportData.Equals(other.DllImportData) + && ForwardedAttributes.SequenceEqual(other.ForwardedAttributes, (IEqualityComparer)new SyntaxEquivalentComparer()) + && Diagnostics.SequenceEqual(other.Diagnostics); } - INamedTypeSymbol? lcidConversionAttrType = context.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); - - INamedTypeSymbol? suppressGCTransitionAttrType = context.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute); - - INamedTypeSymbol? unmanagedCallConvAttrType = context.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute); - - // Fire the start/stop pair for source generation - using var _ = Diagnostics.Events.SourceGenerationStartStop(synRec.Methods.Count); - - // Store a mapping between SyntaxTree and SemanticModel. - // SemanticModels cache results and since we could be looking at - // method declarations in the same SyntaxTree we want to benefit from - // this caching. - var syntaxToModel = new Dictionary(); - - var generatorDiagnostics = new GeneratorDiagnostics(context); + public override int GetHashCode() + { + throw new UnreachableException(); + } + } - bool isSupported = IsSupportedTargetFramework(context.Compilation, out Version targetFrameworkVersion); - if (!isSupported) + public class IncrementalityTracker + { + public enum StepName { - // We don't return early here, letting the source generation continue. - // This allows a user to copy generated source and use it as a starting point - // for manual marshalling if desired. - generatorDiagnostics.ReportTargetFrameworkNotSupported(MinimumSupportedFrameworkVersion); + CalculateStubInformation, + GenerateSingleStub, + NormalizeWhitespace, + ConcatenateStubs, + OutputSourceFile } - var env = new StubEnvironment( - context.Compilation, - isSupported, - targetFrameworkVersion, - context.AnalyzerConfigOptions.GlobalOptions, - context.Compilation.SourceModule.GetAttributes() - .Any(a => a.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute)); + public record ExecutedStepInfo(StepName Step, object Input); - var generatedDllImports = new StringBuilder(); + private List executedSteps = new(); + public IEnumerable ExecutedSteps => executedSteps; - // Mark in source that the file is auto-generated. - generatedDllImports.AppendLine("// "); + internal void RecordExecutedStep(ExecutedStepInfo step) => executedSteps.Add(step); + } - foreach (SyntaxReference synRef in synRec.Methods) - { - var methodSyntax = (MethodDeclarationSyntax)synRef.GetSyntax(context.CancellationToken); + /// + /// This property provides a test-only hook to enable testing the incrementality of the source generator. + /// This will be removed when https://github.com/dotnet/roslyn/issues/54832 is implemented and can be consumed. + /// + public IncrementalityTracker? IncrementalTracker { get; set; } - // Get the model for the method. - if (!syntaxToModel.TryGetValue(methodSyntax.SyntaxTree, out SemanticModel sm)) + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var methodsToGenerate = context.SyntaxProvider + .CreateSyntaxProvider( + static (node, ct) => ShouldVisitNode(node), + static (context, ct) => + new + { + Syntax = (MethodDeclarationSyntax)context.Node, + Symbol = (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, ct)! + }) + .Where( + static modelData => modelData.Symbol.IsStatic && modelData.Symbol.GetAttributes().Any( + static attribute => attribute.AttributeClass?.ToDisplayString() == TypeNames.GeneratedDllImportAttribute) + ); + + var compilationAndTargetFramework = context.CompilationProvider + .Select(static (compilation, ct) => { - sm = context.Compilation.GetSemanticModel(methodSyntax.SyntaxTree, ignoreAccessibility: true); - syntaxToModel.Add(methodSyntax.SyntaxTree, sm); - } - - // Process the method syntax and get its SymbolInfo. - var methodSymbolInfo = sm.GetDeclaredSymbol(methodSyntax, context.CancellationToken)!; - - // Get any attributes of interest on the method - AttributeData? generatedDllImportAttr = null; - AttributeData? lcidConversionAttr = null; - AttributeData? suppressGCTransitionAttribute = null; - AttributeData? unmanagedCallConvAttribute = null; - - foreach (var attr in methodSymbolInfo.GetAttributes()) + bool isSupported = IsSupportedTargetFramework(compilation, out Version targetFrameworkVersion); + return (compilation, isSupported, targetFrameworkVersion); + }); + + context.RegisterSourceOutput( + compilationAndTargetFramework + .Combine(methodsToGenerate.Collect()), + static (context, data) => { - if (attr.AttributeClass is null) + if (!data.Left.isSupported && data.Right.Any()) { - continue; + // We don't block source generation when the TFM is unsupported. + // This allows a user to copy generated source and use it as a starting point + // for manual marshalling if desired. + context.ReportDiagnostic( + Diagnostic.Create( + GeneratorDiagnostics.TargetFrameworkNotSupported, + Location.None, + MinimumSupportedFrameworkVersion.ToString(2))); } - else if (attr.AttributeClass.ToDisplayString() == TypeNames.GeneratedDllImportAttribute) + }); + + var stubEnvironment = compilationAndTargetFramework + .Combine(context.AnalyzerConfigOptionsProvider) + .Select( + static (data, ct) => + new StubEnvironment( + data.Left.compilation, + data.Left.isSupported, + data.Left.targetFrameworkVersion, + data.Right.GlobalOptions, + data.Left.compilation.SourceModule.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute)) + ); + + var methodSourceAndDiagnostics = methodsToGenerate + .Combine(stubEnvironment) + .Select(static (data, ct) => new + { + data.Left.Syntax, + data.Left.Symbol, + Environment = data.Right + }) + .Select( + (data, ct) => { - generatedDllImportAttr = attr; + IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.CalculateStubInformation, data)); + return (data.Syntax, StubContext: CalculateStubInformation(data.Syntax, data.Symbol, data.Environment, ct)); } - else if (lcidConversionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType)) + ) + .WithComparer(Comparers.CalculatedContextWithSyntax) + .Combine(context.AnalyzerConfigOptionsProvider) + .Select( + (data, ct) => { - lcidConversionAttr = attr; + IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.GenerateSingleStub, data)); + return GenerateSource(data.Left.StubContext, data.Left.Syntax, data.Right.GlobalOptions); } - else if (suppressGCTransitionAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType)) + ) + .WithComparer(Comparers.GeneratedSyntax) + // Handle NormalizeWhitespace as a separate stage for incremental runs since it is an expensive operation. + .Select( + (data, ct) => { - suppressGCTransitionAttribute = attr; - } - else if (unmanagedCallConvAttrType is not null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType)) + IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.NormalizeWhitespace, data)); + return (data.Item1.NormalizeWhitespace().ToFullString(), data.Item2); + }) + .Collect() + .WithComparer(Comparers.GeneratedSourceSet) + .Select((generatedSources, ct) => + { + IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.ConcatenateStubs, generatedSources)); + StringBuilder source = new(); + // Mark in source that the file is auto-generated. + source.AppendLine("// "); + ImmutableArray.Builder diagnostics = ImmutableArray.CreateBuilder(); + foreach (var generated in generatedSources) { - unmanagedCallConvAttribute = attr; + source.AppendLine(generated.Item1); + diagnostics.AddRange(generated.Item2); } - } - - if (generatedDllImportAttr == null) - continue; - - // Process the GeneratedDllImport attribute - DllImportStub.GeneratedDllImportData stubDllImportData = this.ProcessGeneratedDllImportAttribute(generatedDllImportAttr); - Debug.Assert(stubDllImportData is not null); - - if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping)) - { - generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.BestFitMapping)); - } - - if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar)) - { - generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar)); - } - - if (stubDllImportData!.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CallingConvention)) - { - generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr, nameof(DllImportStub.GeneratedDllImportData.CallingConvention)); - } + return (source: source.ToString(), diagnostics: diagnostics.ToImmutable()); + }) + .WithComparer(Comparers.GeneratedSource); - if (lcidConversionAttr != null) + context.RegisterSourceOutput(methodSourceAndDiagnostics, + (context, data) => { - // Using LCIDConversion with GeneratedDllImport is not supported - generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute)); - } - - List forwardedAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute); - - // Create the stub. - var dllImportStub = DllImportStub.Create(methodSymbolInfo, stubDllImportData!, env, generatorDiagnostics, forwardedAttributes, context.CancellationToken); - - PrintGeneratedSource(generatedDllImports, methodSyntax, dllImportStub); - } - - Debug.WriteLine(generatedDllImports.ToString()); // [TODO] Find some way to emit this for debugging - logs? - context.AddSource("DllImportGenerator.g.cs", SourceText.From(generatedDllImports.ToString(), Encoding.UTF8)); - } + IncrementalTracker?.RecordExecutedStep(new IncrementalityTracker.ExecutedStepInfo(IncrementalityTracker.StepName.OutputSourceFile, data)); + foreach (var diagnostic in data.Item2) + { + context.ReportDiagnostic(diagnostic); + } - public void Initialize(GeneratorInitializationContext context) - { - context.RegisterForSyntaxNotifications(() => new SyntaxContextReceiver()); + context.AddSource("GeneratedDllImports.g.cs", data.Item1); + }); } - - private List GenerateSyntaxForForwardedAttributes(AttributeData? suppressGCTransitionAttribute, AttributeData? unmanagedCallConvAttribute) + + private static List GenerateSyntaxForForwardedAttributes(AttributeData? suppressGCTransitionAttribute, AttributeData? unmanagedCallConvAttribute) { const string CallConvsField = "CallConvs"; // Manually rehydrate the forwarded attributes with fully qualified types so we don't have to worry about any using directives. @@ -196,7 +218,7 @@ private List GenerateSyntaxForForwardedAttributes(AttributeData return attributes; } - private SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenList) + private static SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenList) { SyntaxToken[] strippedTokens = new SyntaxToken[tokenList.Count]; for (int i = 0; i < tokenList.Count; i++) @@ -206,7 +228,7 @@ private SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenList) return new SyntaxTokenList(strippedTokens); } - private TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclarationSyntax typeDeclaration) + private static TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclarationSyntax typeDeclaration) { return TypeDeclaration( typeDeclaration.Kind(), @@ -215,17 +237,17 @@ private TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclaration .WithModifiers(typeDeclaration.Modifiers); } - private void PrintGeneratedSource( - StringBuilder builder, + private static MemberDeclarationSyntax PrintGeneratedSource( MethodDeclarationSyntax userDeclaredMethod, - DllImportStub stub) + DllImportStubContext stub, + BlockSyntax stubCode) { // Create stub function var stubMethod = MethodDeclaration(stub.StubReturnType, userDeclaredMethod.Identifier) - .AddAttributeLists(stub.AdditionalAttributes) + .AddAttributeLists(stub.AdditionalAttributes.ToArray()) .WithModifiers(StripTriviaFromModifiers(userDeclaredMethod.Modifiers)) .WithParameterList(ParameterList(SeparatedList(stub.StubParameters))) - .WithBody(stub.StubCode); + .WithBody(stubCode); // Stub should have at least one containing type Debug.Assert(stub.StubContainingTypes.Any()); @@ -250,7 +272,7 @@ private void PrintGeneratedSource( .AddMembers(toPrint); } - builder.AppendLine(toPrint.NormalizeWhitespace().ToString()); + return toPrint; } private static bool IsSupportedTargetFramework(Compilation compilation, out Version version) @@ -270,10 +292,8 @@ private static bool IsSupportedTargetFramework(Compilation compilation, out Vers }; } - private DllImportStub.GeneratedDllImportData ProcessGeneratedDllImportAttribute(AttributeData attrData) + private static GeneratedDllImportData ProcessGeneratedDllImportAttribute(AttributeData attrData) { - var stubDllImportData = new DllImportStub.GeneratedDllImportData(); - // Found the GeneratedDllImport, but it has an error so report the error. // This is most likely an issue with targeting an incorrect TFM. if (attrData.AttributeClass?.TypeKind is null or TypeKind.Error) @@ -282,8 +302,21 @@ private DllImportStub.GeneratedDllImportData ProcessGeneratedDllImportAttribute( throw new InvalidProgramException(); } - // Populate the DllImport data from the GeneratedDllImportAttribute attribute. - stubDllImportData.ModuleName = attrData.ConstructorArguments[0].Value!.ToString(); + + // Default values for these properties are based on the + // documented semanatics of DllImportAttribute: + // - https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute + DllImportMember userDefinedValues = DllImportMember.None; + bool bestFitMapping = false; + CallingConvention callingConvention = CallingConvention.Winapi; + CharSet charSet = CharSet.Ansi; + string? entryPoint = null; + bool exactSpelling = false; // VB has different and unusual default behavior here. + bool preserveSig = true; + bool setLastError = false; + bool throwOnUnmappableChar = false; + + var stubDllImportData = new GeneratedDllImportData(attrData.ConstructorArguments[0].Value!.ToString()); // All other data on attribute is defined as NamedArguments. foreach (var namedArg in attrData.NamedArguments) @@ -293,96 +326,179 @@ private DllImportStub.GeneratedDllImportData ProcessGeneratedDllImportAttribute( default: Debug.Fail($"An unknown member was found on {GeneratedDllImport}"); continue; - case nameof(DllImportStub.GeneratedDllImportData.BestFitMapping): - stubDllImportData.BestFitMapping = (bool)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.BestFitMapping; + case nameof(GeneratedDllImportData.BestFitMapping): + userDefinedValues |= DllImportMember.BestFitMapping; + bestFitMapping = (bool)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.CallingConvention): - stubDllImportData.CallingConvention = (CallingConvention)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.CallingConvention; + case nameof(GeneratedDllImportData.CallingConvention): + userDefinedValues |= DllImportMember.CallingConvention; + callingConvention = (CallingConvention)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.CharSet): - stubDllImportData.CharSet = (CharSet)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.CharSet; + case nameof(GeneratedDllImportData.CharSet): + userDefinedValues |= DllImportMember.CharSet; + charSet = (CharSet)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.EntryPoint): - stubDllImportData.EntryPoint = (string)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.EntryPoint; + case nameof(GeneratedDllImportData.EntryPoint): + userDefinedValues |= DllImportMember.EntryPoint; + entryPoint = (string)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.ExactSpelling): - stubDllImportData.ExactSpelling = (bool)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.ExactSpelling; + case nameof(GeneratedDllImportData.ExactSpelling): + userDefinedValues |= DllImportMember.ExactSpelling; + exactSpelling = (bool)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.PreserveSig): - stubDllImportData.PreserveSig = (bool)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.PreserveSig; + case nameof(GeneratedDllImportData.PreserveSig): + userDefinedValues |= DllImportMember.PreserveSig; + preserveSig = (bool)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.SetLastError): - stubDllImportData.SetLastError = (bool)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.SetLastError; + case nameof(GeneratedDllImportData.SetLastError): + userDefinedValues |= DllImportMember.SetLastError; + setLastError = (bool)namedArg.Value.Value!; break; - case nameof(DllImportStub.GeneratedDllImportData.ThrowOnUnmappableChar): - stubDllImportData.ThrowOnUnmappableChar = (bool)namedArg.Value.Value!; - stubDllImportData.IsUserDefined |= DllImportStub.DllImportMember.ThrowOnUnmappableChar; + case nameof(GeneratedDllImportData.ThrowOnUnmappableChar): + userDefinedValues |= DllImportMember.ThrowOnUnmappableChar; + throwOnUnmappableChar = (bool)namedArg.Value.Value!; break; } } - return stubDllImportData; + return new GeneratedDllImportData(attrData.ConstructorArguments[0].Value!.ToString()) + { + IsUserDefined = userDefinedValues, + BestFitMapping = bestFitMapping, + CallingConvention = callingConvention, + CharSet = charSet, + EntryPoint = entryPoint, + ExactSpelling = exactSpelling, + PreserveSig = preserveSig, + SetLastError = setLastError, + ThrowOnUnmappableChar = throwOnUnmappableChar + }; } - - private class SyntaxContextReceiver : ISyntaxContextReceiver + private static IncrementalStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct) { - public ICollection Methods { get; } = new List(); - - public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute); + INamedTypeSymbol? suppressGCTransitionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.SuppressGCTransitionAttribute); + INamedTypeSymbol? unmanagedCallConvAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.UnmanagedCallConvAttribute); + // Get any attributes of interest on the method + AttributeData? generatedDllImportAttr = null; + AttributeData? lcidConversionAttr = null; + AttributeData? suppressGCTransitionAttribute = null; + AttributeData? unmanagedCallConvAttribute = null; + foreach (var attr in symbol.GetAttributes()) { - SyntaxNode syntaxNode = context.Node; - - // We only support C# method declarations. - if (syntaxNode.Language != LanguageNames.CSharp - || !syntaxNode.IsKind(SyntaxKind.MethodDeclaration)) + if (attr.AttributeClass is not null + && attr.AttributeClass.ToDisplayString() == TypeNames.GeneratedDllImportAttribute) { - return; + generatedDllImportAttr = attr; } - - var methodSyntax = (MethodDeclarationSyntax)syntaxNode; - - // Verify the method has no generic types or defined implementation - // and is marked static and partial. - if (!(methodSyntax.TypeParameterList is null) - || !(methodSyntax.Body is null) - || !methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword) - || !methodSyntax.Modifiers.Any(SyntaxKind.PartialKeyword)) + else if (lcidConversionAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, lcidConversionAttrType)) { - return; + lcidConversionAttr = attr; } - - // Verify that the types the method is declared in are marked partial. - for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) + else if (suppressGCTransitionAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, suppressGCTransitionAttrType)) { - if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - return; - } + suppressGCTransitionAttribute = attr; } + else if (unmanagedCallConvAttrType != null && SymbolEqualityComparer.Default.Equals(attr.AttributeClass, unmanagedCallConvAttrType)) + { + unmanagedCallConvAttribute = attr; + } + } + + Debug.Assert(generatedDllImportAttr is not null); + + var generatorDiagnostics = new GeneratorDiagnostics(); + + // Process the GeneratedDllImport attribute + GeneratedDllImportData stubDllImportData = ProcessGeneratedDllImportAttribute(generatedDllImportAttr!); + + if (stubDllImportData.IsUserDefined.HasFlag(DllImportMember.BestFitMapping)) + { + generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr!, nameof(GeneratedDllImportData.BestFitMapping)); + } + + if (stubDllImportData.IsUserDefined.HasFlag(DllImportMember.ThrowOnUnmappableChar)) + { + generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr!, nameof(GeneratedDllImportData.ThrowOnUnmappableChar)); + } + + if (stubDllImportData.IsUserDefined.HasFlag(DllImportMember.CallingConvention)) + { + generatorDiagnostics.ReportConfigurationNotSupported(generatedDllImportAttr!, nameof(GeneratedDllImportData.CallingConvention)); + } + + if (lcidConversionAttr != null) + { + // Using LCIDConversion with GeneratedDllImport is not supported + generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute)); + } + List additionalAttributes = GenerateSyntaxForForwardedAttributes(suppressGCTransitionAttribute, unmanagedCallConvAttribute); - // Check if the method is marked with the GeneratedDllImport attribute. - foreach (AttributeListSyntax listSyntax in methodSyntax.AttributeLists) + // Create the stub. + var dllImportStub = DllImportStubContext.Create(symbol, stubDllImportData, environment, generatorDiagnostics, ct); + + return new IncrementalStubGenerationContext(dllImportStub, additionalAttributes.ToImmutableArray(), stubDllImportData, generatorDiagnostics.Diagnostics.ToImmutableArray()); + } + + private (MemberDeclarationSyntax, ImmutableArray) GenerateSource( + IncrementalStubGenerationContext dllImportStub, + MethodDeclarationSyntax originalSyntax, + AnalyzerConfigOptions options) + { + var diagnostics = new GeneratorDiagnostics(); + + // Generate stub code + var stubGenerator = new StubCodeGenerator( + dllImportStub.DllImportData, + dllImportStub.StubContext.ElementTypeInformation, + options, + (elementInfo, ex) => diagnostics.ReportMarshallingNotSupported(originalSyntax, elementInfo, ex.NotSupportedDetails)); + + ImmutableArray forwardedAttributes = dllImportStub.ForwardedAttributes; + + var code = stubGenerator.GenerateBody(originalSyntax.Identifier.Text, forwardedAttributes: forwardedAttributes.Length != 0 ? AttributeList(SeparatedList(forwardedAttributes)) : null); + + return (PrintGeneratedSource(originalSyntax, dllImportStub.StubContext, code), dllImportStub.Diagnostics.AddRange(diagnostics.Diagnostics)); + } + + private static bool ShouldVisitNode(SyntaxNode syntaxNode) + { + // We only support C# method declarations. + if (syntaxNode.Language != LanguageNames.CSharp + || !syntaxNode.IsKind(SyntaxKind.MethodDeclaration)) + { + return false; + } + + var methodSyntax = (MethodDeclarationSyntax)syntaxNode; + + // Verify the method has no generic types or defined implementation + // and is marked static and partial. + if (methodSyntax.TypeParameterList is not null + || methodSyntax.Body is not null + || !methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword) + || !methodSyntax.Modifiers.Any(SyntaxKind.PartialKeyword)) + { + return false; + } + + // Verify that the types the method is declared in are marked partial. + for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent) + { + if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword)) { - foreach (AttributeSyntax attrSyntax in listSyntax.Attributes) - { - SymbolInfo info = context.SemanticModel.GetSymbolInfo(attrSyntax); - if (info.Symbol is IMethodSymbol attrConstructor - && attrConstructor.ContainingType.ToDisplayString() == TypeNames.GeneratedDllImportAttribute) - { - this.Methods.Add(syntaxNode.GetReference()); - return; - } - } + return false; } } + + // Filter out methods with no attributes early. + if (methodSyntax.AttributeLists.Count == 0) + { + return false; + } + + return true; } } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStub.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStubContext.cs similarity index 62% rename from src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStub.cs rename to src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStubContext.cs index c82095e5b5a0ee..e86b8fdef826a8 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStub.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/DllImportStubContext.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Linq; using System.Runtime.InteropServices; using System.Threading; @@ -19,101 +20,50 @@ internal record StubEnvironment( AnalyzerConfigOptions Options, bool ModuleSkipLocalsInit); - internal class DllImportStub + internal sealed class DllImportStubContext : IEquatable { - private TypePositionInfo returnTypeInfo; - private IEnumerable paramsTypeInfo; - // We don't need the warnings around not setting the various // non-nullable fields/properties on this type in the constructor // since we always use a property initializer. #pragma warning disable 8618 - private DllImportStub() + private DllImportStubContext() { } #pragma warning restore + public ImmutableArray ElementTypeInformation { get; init; } + public string? StubTypeNamespace { get; init; } - public IEnumerable StubContainingTypes { get; init; } + public ImmutableArray StubContainingTypes { get; init; } - public TypeSyntax StubReturnType { get => this.returnTypeInfo.ManagedType.AsTypeSyntax(); } + public TypeSyntax StubReturnType { get; init; } public IEnumerable StubParameters { get { - foreach (var typeinfo in paramsTypeInfo) + foreach (var typeInfo in ElementTypeInformation) { - if (typeinfo.ManagedIndex != TypePositionInfo.UnsetIndex - && typeinfo.ManagedIndex != TypePositionInfo.ReturnIndex) + if (typeInfo.ManagedIndex != TypePositionInfo.UnsetIndex + && typeInfo.ManagedIndex != TypePositionInfo.ReturnIndex) { - yield return Parameter(Identifier(typeinfo.InstanceIdentifier)) - .WithType(typeinfo.ManagedType.AsTypeSyntax()) - .WithModifiers(TokenList(Token(typeinfo.RefKindSyntax))); + yield return Parameter(Identifier(typeInfo.InstanceIdentifier)) + .WithType(typeInfo.ManagedType.Syntax) + .WithModifiers(TokenList(Token(typeInfo.RefKindSyntax))); } } } } - public BlockSyntax StubCode { get; init; } + public ImmutableArray AdditionalAttributes { get; init; } - public AttributeListSyntax[] AdditionalAttributes { get; init; } - - /// - /// Flags used to indicate members on GeneratedDllImport attribute. - /// - [Flags] - public enum DllImportMember - { - None = 0, - BestFitMapping = 1 << 0, - CallingConvention = 1 << 1, - CharSet = 1 << 2, - EntryPoint = 1 << 3, - ExactSpelling = 1 << 4, - PreserveSig = 1 << 5, - SetLastError = 1 << 6, - ThrowOnUnmappableChar = 1 << 7, - All = ~None - } - - /// - /// GeneratedDllImportAttribute data - /// - /// - /// The names of these members map directly to those on the - /// DllImportAttribute and should not be changed. - /// - public class GeneratedDllImportData - { - public string ModuleName { get; set; } = null!; - - /// - /// Value set by the user on the original declaration. - /// - public DllImportMember IsUserDefined = DllImportMember.None; - - // Default values for the below fields are based on the - // documented semanatics of DllImportAttribute: - // - https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute - public bool BestFitMapping { get; set; } = true; - public CallingConvention CallingConvention { get; set; } = CallingConvention.Winapi; - public CharSet CharSet { get; set; } = CharSet.Ansi; - public string EntryPoint { get; set; } = null!; - public bool ExactSpelling { get; set; } = false; // VB has different and unusual default behavior here. - public bool PreserveSig { get; set; } = true; - public bool SetLastError { get; set; } = false; - public bool ThrowOnUnmappableChar { get; set; } = false; - } - - public static DllImportStub Create( + public static DllImportStubContext Create( IMethodSymbol method, GeneratedDllImportData dllImportData, StubEnvironment env, GeneratorDiagnostics diagnostics, - List forwardedAttributes, - CancellationToken token = default) + CancellationToken token) { // Cancel early if requested token.ThrowIfCancellationRequested(); @@ -127,7 +77,7 @@ public static DllImportStub Create( } // Determine containing type(s) - var containingTypes = new List(); + var containingTypes = ImmutableArray.CreateBuilder(); INamedTypeSymbol currType = method.ContainingType; while (!(currType is null)) { @@ -145,6 +95,36 @@ public static DllImportStub Create( currType = currType.ContainingType; } + var typeInfos = GenerateTypeInformation(method, dllImportData, diagnostics, env); + + var additionalAttrs = ImmutableArray.CreateBuilder(); + + // Define additional attributes for the stub definition. + if (env.TargetFrameworkVersion >= new Version(5, 0) && !MethodIsSkipLocalsInit(env, method)) + { + additionalAttrs.Add( + AttributeList( + SeparatedList(new[] + { + // Adding the skip locals init indiscriminately since the source generator is + // targeted at non-blittable method signatures which typically will contain locals + // in the generated code. + Attribute(ParseName(TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute)) + }))); + } + + return new DllImportStubContext() + { + StubReturnType = method.ReturnType.AsTypeSyntax(), + ElementTypeInformation = typeInfos, + StubTypeNamespace = stubTypeNamespace, + StubContainingTypes = containingTypes.ToImmutable(), + AdditionalAttributes = additionalAttrs.ToImmutable(), + }; + } + + private static ImmutableArray GenerateTypeInformation(IMethodSymbol method, GeneratedDllImportData dllImportData, GeneratorDiagnostics diagnostics, StubEnvironment env) + { // Compute the current default string encoding value. var defaultEncoding = CharEncoding.Undefined; if (dllImportData.IsUserDefined.HasFlag(DllImportMember.CharSet)) @@ -163,21 +143,22 @@ public static DllImportStub Create( var marshallingAttributeParser = new MarshallingAttributeInfoParser(env.Compilation, diagnostics, defaultInfo, method); // Determine parameter and return types - var paramsTypeInfo = new List(); + var typeInfos = ImmutableArray.CreateBuilder(); for (int i = 0; i < method.Parameters.Length; i++) { var param = method.Parameters[i]; MarshallingInfo marshallingInfo = marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation); - typeInfo = typeInfo with + typeInfo = typeInfo with { ManagedIndex = i, - NativeIndex = paramsTypeInfo.Count + NativeIndex = typeInfos.Count }; - paramsTypeInfo.Add(typeInfo); + typeInfos.Add(typeInfo); + } - TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); + TypePositionInfo retTypeInfo = new(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(method.ReturnType), marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); retTypeInfo = retTypeInfo with { ManagedIndex = TypePositionInfo.ReturnIndex, @@ -190,7 +171,7 @@ public static DllImportStub Create( if (!dllImportData.PreserveSig && !env.Options.GenerateForwarders()) { // Create type info for native HRESULT return - retTypeInfo = TypePositionInfo.CreateForType(env.Compilation.GetSpecialType(SpecialType.System_Int32), NoMarshallingInfo.Instance); + retTypeInfo = new TypePositionInfo(SpecialTypeInfo.Int32, NoMarshallingInfo.Instance); retTypeInfo = retTypeInfo with { NativeIndex = TypePositionInfo.ReturnIndex @@ -206,41 +187,34 @@ public static DllImportStub Create( RefKind = RefKind.Out, RefKindSyntax = SyntaxKind.OutKeyword, ManagedIndex = TypePositionInfo.ReturnIndex, - NativeIndex = paramsTypeInfo.Count + NativeIndex = typeInfos.Count }; - paramsTypeInfo.Add(nativeOutInfo); + typeInfos.Add(nativeOutInfo); } } + typeInfos.Add(retTypeInfo); - // Generate stub code - var stubGenerator = new StubCodeGenerator(method, dllImportData, paramsTypeInfo, retTypeInfo, diagnostics, env.Options); - var code = stubGenerator.GenerateSyntax(forwardedAttributes: forwardedAttributes.Count != 0 ? AttributeList(SeparatedList(forwardedAttributes)) : null); + return typeInfos.ToImmutable(); + } - var additionalAttrs = new List(); + public override bool Equals(object obj) + { + return obj is DllImportStubContext other && Equals(other); + } - // Define additional attributes for the stub definition. - if (env.TargetFrameworkVersion >= new Version(5, 0) && !MethodIsSkipLocalsInit(env, method)) - { - additionalAttrs.Add( - AttributeList( - SeparatedList(new [] - { - // Adding the skip locals init indiscriminately since the source generator is - // targeted at non-blittable method signatures which typically will contain locals - // in the generated code. - Attribute(ParseName(TypeNames.System_Runtime_CompilerServices_SkipLocalsInitAttribute)) - }))); - } + public bool Equals(DllImportStubContext other) + { + return other is not null + && StubTypeNamespace == other.StubTypeNamespace + && ElementTypeInformation.SequenceEqual(other.ElementTypeInformation) + && StubContainingTypes.SequenceEqual(other.StubContainingTypes, (IEqualityComparer)new SyntaxEquivalentComparer()) + && StubReturnType.IsEquivalentTo(other.StubReturnType) + && AdditionalAttributes.SequenceEqual(other.AdditionalAttributes, (IEqualityComparer)new SyntaxEquivalentComparer()); + } - return new DllImportStub() - { - returnTypeInfo = managedRetTypeInfo, - paramsTypeInfo = paramsTypeInfo, - StubTypeNamespace = stubTypeNamespace, - StubContainingTypes = containingTypes, - StubCode = code, - AdditionalAttributes = additionalAttrs.ToArray(), - }; + public override int GetHashCode() + { + throw new UnreachableException(); } private static bool MethodIsSkipLocalsInit(StubEnvironment env, IMethodSymbol method) diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs new file mode 100644 index 00000000000000..9a4fc90a125bfe --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratedDllImportData.cs @@ -0,0 +1,48 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.Interop +{ + /// + /// Flags used to indicate members on GeneratedDllImport attribute. + /// + [Flags] + public enum DllImportMember + { + None = 0, + BestFitMapping = 1 << 0, + CallingConvention = 1 << 1, + CharSet = 1 << 2, + EntryPoint = 1 << 3, + ExactSpelling = 1 << 4, + PreserveSig = 1 << 5, + SetLastError = 1 << 6, + ThrowOnUnmappableChar = 1 << 7, + All = ~None + } + + /// + /// GeneratedDllImportAttribute data + /// + /// + /// The names of these members map directly to those on the + /// DllImportAttribute and should not be changed. + /// + public sealed record GeneratedDllImportData(string ModuleName) + { + /// + /// Value set by the user on the original declaration. + /// + public DllImportMember IsUserDefined { get; init; } + public bool BestFitMapping { get; init; } + public CallingConvention CallingConvention { get; init; } + public CharSet CharSet { get; init; } + public string? EntryPoint { get; init; } + public bool ExactSpelling { get; init; } + public bool PreserveSig { get; init; } + public bool SetLastError { get; init; } + public bool ThrowOnUnmappableChar { get; init; } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs index 51531f6c1a00b9..74abfb3c71497f 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/GeneratorDiagnostics.cs @@ -5,6 +5,7 @@ using System.Linq; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; namespace Microsoft.Interop { @@ -15,15 +16,20 @@ public static Diagnostic CreateDiagnostic( DiagnosticDescriptor descriptor, params object[] args) { - IEnumerable locationsInSource = symbol.Locations.Where(l => l.IsInSource); - if (!locationsInSource.Any()) - return Diagnostic.Create(descriptor, Location.None, args); + return symbol.Locations.CreateDiagnostic(descriptor, args); + } - return Diagnostic.Create( - descriptor, - location: locationsInSource.First(), - additionalLocations: locationsInSource.Skip(1), - messageArgs: args); + public static Diagnostic CreateDiagnostic( + this AttributeData attributeData, + DiagnosticDescriptor descriptor, + params object[] args) + { + SyntaxReference? syntaxReference = attributeData.ApplicationSyntaxReference; + Location location = syntaxReference is not null + ? syntaxReference.GetSyntax().GetLocation() + : Location.None; + + return location.CreateDiagnostic(descriptor, args); } public static Diagnostic CreateDiagnostic( @@ -43,15 +49,10 @@ public static Diagnostic CreateDiagnostic( } public static Diagnostic CreateDiagnostic( - this AttributeData attributeData, + this Location location, DiagnosticDescriptor descriptor, params object[] args) { - SyntaxReference? syntaxReference = attributeData.ApplicationSyntaxReference; - Location location = syntaxReference is not null - ? syntaxReference.GetSyntax().GetLocation() - : Location.None; - return Diagnostic.Create( descriptor, location: location.IsInSource ? location : Location.None, @@ -174,12 +175,9 @@ public class Ids isEnabledByDefault: true, description: GetResourceString(nameof(Resources.TargetFrameworkNotSupportedDescription))); - private readonly GeneratorExecutionContext context; + private readonly List diagnostics = new List(); - public GeneratorDiagnostics(GeneratorExecutionContext context) - { - this.context = context; - } + public IEnumerable Diagnostics => diagnostics; /// /// Report diagnostic for configuration that is not supported by the DLL import source generator @@ -194,14 +192,14 @@ public void ReportConfigurationNotSupported( { if (unsupportedValue == null) { - this.context.ReportDiagnostic( + diagnostics.Add( attributeData.CreateDiagnostic( GeneratorDiagnostics.ConfigurationNotSupported, configurationName)); } else { - this.context.ReportDiagnostic( + diagnostics.Add( attributeData.CreateDiagnostic( GeneratorDiagnostics.ConfigurationValueNotSupported, unsupportedValue, @@ -216,30 +214,44 @@ public void ReportConfigurationNotSupported( /// Type info for the parameter/return /// [Optional] Specific reason for lack of support internal void ReportMarshallingNotSupported( - IMethodSymbol method, + MethodDeclarationSyntax method, TypePositionInfo info, string? notSupportedDetails) { + Location diagnosticLocation = Location.None; + string elementName = string.Empty; + + if (info.IsManagedReturnPosition) + { + diagnosticLocation = Location.Create(method.SyntaxTree, method.Identifier.Span); + elementName = method.Identifier.ValueText; + } + else + { + Debug.Assert(info.ManagedIndex <= method.ParameterList.Parameters.Count); + ParameterSyntax param = method.ParameterList.Parameters[info.ManagedIndex]; + diagnosticLocation = Location.Create(param.SyntaxTree, param.Identifier.Span); + elementName = param.Identifier.ValueText; + } + if (!string.IsNullOrEmpty(notSupportedDetails)) { // Report the specific not-supported reason. if (info.IsManagedReturnPosition) { - this.context.ReportDiagnostic( - method.CreateDiagnostic( + diagnostics.Add( + diagnosticLocation.CreateDiagnostic( GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails, notSupportedDetails!, - method.Name)); + elementName)); } else { - Debug.Assert(info.ManagedIndex <= method.Parameters.Length); - IParameterSymbol paramSymbol = method.Parameters[info.ManagedIndex]; - this.context.ReportDiagnostic( - paramSymbol.CreateDiagnostic( + diagnostics.Add( + diagnosticLocation.CreateDiagnostic( GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails, notSupportedDetails!, - paramSymbol.Name)); + elementName)); } } else if (info.MarshallingAttributeInfo is MarshalAsInfo) @@ -249,21 +261,19 @@ internal void ReportMarshallingNotSupported( // than when there is no attribute and the type itself is not supported. if (info.IsManagedReturnPosition) { - this.context.ReportDiagnostic( - method.CreateDiagnostic( + diagnostics.Add( + diagnosticLocation.CreateDiagnostic( GeneratorDiagnostics.ReturnConfigurationNotSupported, nameof(System.Runtime.InteropServices.MarshalAsAttribute), - method.Name)); + elementName)); } else { - Debug.Assert(info.ManagedIndex <= method.Parameters.Length); - IParameterSymbol paramSymbol = method.Parameters[info.ManagedIndex]; - this.context.ReportDiagnostic( - paramSymbol.CreateDiagnostic( + diagnostics.Add( + diagnosticLocation.CreateDiagnostic( GeneratorDiagnostics.ParameterConfigurationNotSupported, nameof(System.Runtime.InteropServices.MarshalAsAttribute), - paramSymbol.Name)); + elementName)); } } else @@ -271,21 +281,19 @@ internal void ReportMarshallingNotSupported( // Report that the type is not supported if (info.IsManagedReturnPosition) { - this.context.ReportDiagnostic( - method.CreateDiagnostic( + diagnostics.Add( + diagnosticLocation.CreateDiagnostic( GeneratorDiagnostics.ReturnTypeNotSupported, - method.ReturnType.ToDisplayString(), - method.Name)); + info.ManagedType.DiagnosticFormattedName, + elementName)); } else { - Debug.Assert(info.ManagedIndex <= method.Parameters.Length); - IParameterSymbol paramSymbol = method.Parameters[info.ManagedIndex]; - this.context.ReportDiagnostic( - paramSymbol.CreateDiagnostic( + diagnostics.Add( + diagnosticLocation.CreateDiagnostic( GeneratorDiagnostics.ParameterTypeNotSupported, - paramSymbol.Type.ToDisplayString(), - paramSymbol.Name)); + info.ManagedType.DiagnosticFormattedName, + elementName)); } } } @@ -295,7 +303,7 @@ internal void ReportInvalidMarshallingAttributeInfo( string reasonResourceName, params string[] reasonArgs) { - this.context.ReportDiagnostic( + diagnostics.Add( attributeData.CreateDiagnostic( GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported, new LocalizableResourceString(reasonResourceName, Resources.ResourceManager, typeof(Resources), reasonArgs))); @@ -307,7 +315,7 @@ internal void ReportInvalidMarshallingAttributeInfo( /// Minimum supported version of .NET public void ReportTargetFrameworkNotSupported(Version minimumSupportedVersion) { - this.context.ReportDiagnostic( + diagnostics.Add( Diagnostic.Create( TargetFrameworkNotSupported, Location.None, diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs new file mode 100644 index 00000000000000..9e2bd6f70d8a4d --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/ManagedTypeInfo.cs @@ -0,0 +1,74 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Interop +{ + /// + /// A discriminated union that contains enough info about a managed type to determine a marshalling generator and generate code. + /// + internal abstract record ManagedTypeInfo(string FullTypeName, string DiagnosticFormattedName) + { + public TypeSyntax Syntax { get; } = SyntaxFactory.ParseTypeName(FullTypeName); + + public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type) + { + string typeName = type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + string diagonsticFormattedName = type.ToDisplayString(); + if (type.SpecialType != SpecialType.None) + { + return new SpecialTypeInfo(typeName, diagonsticFormattedName, type.SpecialType); + } + if (type.TypeKind == TypeKind.Enum) + { + return new EnumTypeInfo(typeName, diagonsticFormattedName, ((INamedTypeSymbol)type).EnumUnderlyingType!.SpecialType); + } + if (type.TypeKind == TypeKind.Pointer) + { + return new PointerTypeInfo(typeName, diagonsticFormattedName, IsFunctionPointer: false); + } + if (type.TypeKind == TypeKind.FunctionPointer) + { + return new PointerTypeInfo(typeName, diagonsticFormattedName, IsFunctionPointer: true); + } + if (type.TypeKind == TypeKind.Array && type is IArrayTypeSymbol { IsSZArray: true } arraySymbol) + { + return new SzArrayType(CreateTypeInfoForTypeSymbol(arraySymbol.ElementType)); + } + if (type.TypeKind == TypeKind.Delegate) + { + return new DelegateTypeInfo(typeName, diagonsticFormattedName); + } + return new SimpleManagedTypeInfo(typeName, diagonsticFormattedName); + } + } + + internal sealed record SpecialTypeInfo(string FullTypeName, string DiagnosticFormattedName, SpecialType SpecialType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName) + { + public static readonly SpecialTypeInfo Int32 = new("int", "int", SpecialType.System_Int32); + public static readonly SpecialTypeInfo Void = new("void", "void", SpecialType.System_Void); + + public bool Equals(SpecialTypeInfo? other) + { + return other is not null && SpecialType == other.SpecialType; + } + + public override int GetHashCode() + { + return (int)SpecialType; + } + } + + internal sealed record EnumTypeInfo(string FullTypeName, string DiagnosticFormattedName, SpecialType UnderlyingType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); + + internal sealed record PointerTypeInfo(string FullTypeName, string DiagnosticFormattedName, bool IsFunctionPointer) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); + + internal sealed record SzArrayType(ManagedTypeInfo ElementTypeInfo) : ManagedTypeInfo($"{ElementTypeInfo.FullTypeName}[]", $"{ElementTypeInfo.DiagnosticFormattedName}[]"); + + internal sealed record DelegateTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); + + internal sealed record SimpleManagedTypeInfo(string FullTypeName, string DiagnosticFormattedName) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName); +} diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs index 0f158706cd038f..15fa7774a996f2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BlittableMarshaller.cs @@ -11,7 +11,7 @@ internal class BlittableMarshaller : IMarshallingGenerator { public TypeSyntax AsNativeType(TypePositionInfo info) { - return info.ManagedType.AsTypeSyntax(); + return info.ManagedType.Syntax; } public ParameterSyntax AsParameter(TypePositionInfo info) diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs index b659ccbda4b785..2b5af8d2a5cd80 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/BoolMarshaller.cs @@ -25,7 +25,7 @@ protected BoolMarshallerBase(PredefinedTypeSyntax nativeType, int trueValue, int public TypeSyntax AsNativeType(TypePositionInfo info) { - Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Boolean); + Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Boolean)); return _nativeType; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs index b3279390da77f9..d04a78c2460ae7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/CharMarshaller.cs @@ -33,7 +33,7 @@ public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) public TypeSyntax AsNativeType(TypePositionInfo info) { - Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Char); + Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Char)); return NativeType; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs index 21a5c80d4cca9a..b5aca0f44207f7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/DelegateMarshaller.cs @@ -89,7 +89,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont .WithTypeArgumentList( TypeArgumentList( SingletonSeparatedList( - info.ManagedType.AsTypeSyntax())))), + info.ManagedType.Syntax)))), ArgumentList(SingletonSeparatedList(Argument(IdentifierName(nativeIdentifier))))), LiteralExpression(SyntaxKind.NullLiteralExpression)))); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs index 1bd26150668f68..e86c422ec48289 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/Forwarder.cs @@ -12,7 +12,7 @@ internal class Forwarder : IMarshallingGenerator, IAttributedReturnTypeMarshalli { public TypeSyntax AsNativeType(TypePositionInfo info) { - return info.ManagedType.AsTypeSyntax(); + return info.ManagedType.Syntax; } private bool TryRehydrateMarshalAsAttribute(TypePositionInfo info, out AttributeSyntax marshalAsAttribute) @@ -87,7 +87,7 @@ public ParameterSyntax AsParameter(TypePositionInfo info) { ParameterSyntax param = Parameter(Identifier(info.InstanceIdentifier)) .WithModifiers(TokenList(Token(info.RefKindSyntax))) - .WithType(info.ManagedType.AsTypeSyntax()); + .WithType(info.ManagedType.Syntax); if (TryRehydrateMarshalAsAttribute(info, out AttributeSyntax marshalAsAttribute)) { diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs index 4b56c8a3107253..09ce81431e6a93 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/HResultExceptionMarshaller.cs @@ -15,7 +15,7 @@ internal sealed class HResultExceptionMarshaller : IMarshallingGenerator public TypeSyntax AsNativeType(TypePositionInfo info) { - Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Int32); + Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Int32)); return NativeType; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs index 946e96fe56220f..24fb56ad90bcf9 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -194,31 +194,31 @@ private static IMarshallingGenerator CreateCore( if (info.IsNativeReturnPosition && !info.IsManagedReturnPosition) { // Use marshaller for native HRESULT return / exception throwing - System.Diagnostics.Debug.Assert(info.ManagedType.SpecialType == SpecialType.System_Int32); + System.Diagnostics.Debug.Assert(info.ManagedType is SpecialTypeInfo { SpecialType: SpecialType.System_Int32 }); return HResultException; } switch (info) { // Blittable primitives with no marshalling info or with a compatible [MarshalAs] attribute. - case { ManagedType: { SpecialType: SpecialType.System_SByte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I1, _) } - or { ManagedType: { SpecialType: SpecialType.System_Byte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U1, _) } - or { ManagedType: { SpecialType: SpecialType.System_Int16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I2, _) } - or { ManagedType: { SpecialType: SpecialType.System_UInt16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U2, _) } - or { ManagedType: { SpecialType: SpecialType.System_Int32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I4, _) } - or { ManagedType: { SpecialType: SpecialType.System_UInt32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U4, _) } - or { ManagedType: { SpecialType: SpecialType.System_Int64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I8, _) } - or { ManagedType: { SpecialType: SpecialType.System_UInt64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U8, _) } - or { ManagedType: { SpecialType: SpecialType.System_IntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysInt, _) } - or { ManagedType: { SpecialType: SpecialType.System_UIntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysUInt, _) } - or { ManagedType: { SpecialType: SpecialType.System_Single }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R4, _) } - or { ManagedType: { SpecialType: SpecialType.System_Double }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R8, _) }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_SByte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I1, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Byte }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U1, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Int16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I2, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UInt16 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U2, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Int32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I4, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UInt32 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U4, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Int64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.I8, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UInt64 }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.U8, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_IntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysInt, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_UIntPtr }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.SysUInt, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Single }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R4, _) } + or { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Double }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.R8, _) }: return Blittable; // Enum with no marshalling info - case { ManagedType: { TypeKind: TypeKind.Enum }, MarshallingAttributeInfo: NoMarshallingInfo }: + case { ManagedType: EnumTypeInfo enumType, MarshallingAttributeInfo: NoMarshallingInfo }: // Check that the underlying type is not bool or char. C# does not allow this, but ECMA-335 does. - var underlyingSpecialType = ((INamedTypeSymbol)info.ManagedType).EnumUnderlyingType!.SpecialType; + var underlyingSpecialType = enumType.UnderlyingType; if (underlyingSpecialType == SpecialType.System_Boolean || underlyingSpecialType == SpecialType.System_Char) { throw new MarshallingNotSupportedException(info, context); @@ -226,31 +226,31 @@ private static IMarshallingGenerator CreateCore( return Blittable; // Pointer with no marshalling info - case { ManagedType: { TypeKind: TypeKind.Pointer }, MarshallingAttributeInfo: NoMarshallingInfo }: + case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer:false), MarshallingAttributeInfo: NoMarshallingInfo }: return Blittable; // Function pointer with no marshalling info - case { ManagedType: { TypeKind: TypeKind.FunctionPointer }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }: + case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer: true), MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }: return Blittable; - case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: NoMarshallingInfo }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: NoMarshallingInfo }: return WinBool; // [Compat] Matching the default for the built-in runtime marshallers. - case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I1 or UnmanagedType.U1, _) }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I1 or UnmanagedType.U1, _) }: return ByteBool; - case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I4 or UnmanagedType.U4 or UnmanagedType.Bool, _) }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.I4 or UnmanagedType.U4 or UnmanagedType.Bool, _) }: return WinBool; - case { ManagedType: { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.VariantBool, _) }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Boolean }, MarshallingAttributeInfo: MarshalAsInfo(UnmanagedType.VariantBool, _) }: return VariantBool; - case { ManagedType: { TypeKind: TypeKind.Delegate }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }: + case { ManagedType: DelegateTypeInfo, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }: return Delegate; - case { MarshallingAttributeInfo: SafeHandleMarshallingInfo }: + case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }: if (!context.AdditionalTemporaryStateLivesAcrossStages) { throw new MarshallingNotSupportedException(info, context); } - if (info.IsByRef && info.ManagedType.IsAbstract) + if (info.IsByRef && isAbstract) { throw new MarshallingNotSupportedException(info, context) { @@ -274,13 +274,13 @@ private static IMarshallingGenerator CreateCore( // Cases that just match on type must come after the checks that match only on marshalling attribute info. // The checks below do not account for generic marshalling overrides like [MarshalUsing], so those checks must come first. - case { ManagedType: { SpecialType: SpecialType.System_Char } }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Char } }: return CreateCharMarshaller(info, context); - case { ManagedType: { SpecialType: SpecialType.System_String } }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_String } }: return CreateStringMarshaller(info, context); - case { ManagedType: { SpecialType: SpecialType.System_Void } }: + case { ManagedType: SpecialTypeInfo { SpecialType: SpecialType.System_Void } }: return Forwarder; default: @@ -403,7 +403,7 @@ ExpressionSyntax GetExpressionForParam(TypePositionInfo paramInfo) paramInfo, out int numIndirectionLevels); - ITypeSymbol type = paramInfo.ManagedType; + ManagedTypeInfo type = paramInfo.ManagedType; MarshallingInfo marshallingInfo = paramInfo.MarshallingAttributeInfo; for (int i = 0; i < numIndirectionLevels; i++) @@ -422,7 +422,7 @@ ExpressionSyntax GetExpressionForParam(TypePositionInfo paramInfo) } } - if (!type.IsIntegralType()) + if (type is not SpecialTypeInfo specialType || !specialType.SpecialType.IsIntegralType()) { throw new MarshallingNotSupportedException(info, context) { @@ -470,14 +470,14 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi { ValidateCustomNativeTypeMarshallingSupported(info, context, marshalInfo); - ICustomNativeTypeMarshallingStrategy marshallingStrategy = new SimpleCustomNativeTypeMarshalling(marshalInfo.NativeMarshallingType.AsTypeSyntax()); + ICustomNativeTypeMarshallingStrategy marshallingStrategy = new SimpleCustomNativeTypeMarshalling(marshalInfo.NativeMarshallingType.Syntax); - if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0) + if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNativeStackalloc) != 0) { marshallingStrategy = new StackallocOptimizationMarshalling(marshallingStrategy); } - if (ManualTypeMarshallingHelper.HasFreeNativeMethod(marshalInfo.NativeMarshallingType)) + if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.FreeNativeResources) != 0) { marshallingStrategy = new FreeNativeCleanupStrategy(marshallingStrategy); } @@ -495,7 +495,7 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); - if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0) + if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedTypePinning) != 0) { return new PinnableManagedValueMarshaller(marshallingGenerator); } @@ -508,53 +508,54 @@ private static void ValidateCustomNativeTypeMarshallingSupported(TypePositionInf // The marshalling method for this type doesn't support marshalling from native to managed, // but our scenario requires marshalling from native to managed. if ((info.RefKind == RefKind.Ref || info.RefKind == RefKind.Out || info.IsManagedReturnPosition) - && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.NativeToManaged) == 0) + && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.NativeToManaged) == 0) { throw new MarshallingNotSupportedException(info, context) { - NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingNativeToManagedUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString()) + NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingNativeToManagedUnsupported, marshalInfo.NativeMarshallingType.FullTypeName) }; } // The marshalling method for this type doesn't support marshalling from managed to native by value, // but our scenario requires marshalling from managed to native by value. else if (!info.IsByRef - && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0 - && (context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingMethods & (SupportedMarshallingMethods.Pinning | SupportedMarshallingMethods.ManagedToNativeStackalloc)) == 0)) + && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNative) == 0 + && (context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingFeatures & (CustomMarshallingFeatures.ManagedTypePinning | CustomMarshallingFeatures.ManagedToNativeStackalloc)) == 0)) { throw new MarshallingNotSupportedException(info, context) { - NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString()) + NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.FullTypeName) }; } // The marshalling method for this type doesn't support marshalling from managed to native by reference, // but our scenario requires marshalling from managed to native by reference. // "in" byref supports stack marshalling. else if (info.RefKind == RefKind.In - && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0 - && !(context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0)) + && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNative) == 0 + && !(context.SingleFrameSpansNativeContext && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNativeStackalloc) != 0)) { throw new MarshallingNotSupportedException(info, context) { - NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString()) + NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.FullTypeName) }; } // The marshalling method for this type doesn't support marshalling from managed to native by reference, // but our scenario requires marshalling from managed to native by reference. // "ref" byref marshalling doesn't support stack marshalling else if (info.RefKind == RefKind.Ref - && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0) + && (marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedToNative) == 0) { throw new MarshallingNotSupportedException(info, context) { - NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString()) + NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.FullTypeName) }; } } private static ICustomNativeTypeMarshallingStrategy DecorateWithValuePropertyStrategy(NativeMarshallingAttributeInfo marshalInfo, ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller) { - TypeSyntax valuePropertyTypeSyntax = marshalInfo.ValuePropertyType!.AsTypeSyntax(); - if (ManualTypeMarshallingHelper.FindGetPinnableReference(marshalInfo.NativeMarshallingType) is not null) + TypeSyntax valuePropertyTypeSyntax = marshalInfo.ValuePropertyType!.Syntax; + + if ((marshalInfo.MarshallingFeatures & CustomMarshallingFeatures.NativeTypePinning) != 0) { return new PinnableMarshallerTypeMarshalling(nativeTypeMarshaller, valuePropertyTypeSyntax); } @@ -569,7 +570,7 @@ private static IMarshallingGenerator CreateNativeCollectionMarshaller( AnalyzerConfigOptions options, ICustomNativeTypeMarshallingStrategy marshallingStrategy) { - var elementInfo = TypePositionInfo.CreateForType(collectionInfo.ElementType, collectionInfo.ElementMarshallingInfo) with { ManagedIndex = info.ManagedIndex }; + var elementInfo = new TypePositionInfo(collectionInfo.ElementType, collectionInfo.ElementMarshallingInfo) { ManagedIndex = info.ManagedIndex }; var elementMarshaller = Create( elementInfo, new ContiguousCollectionElementMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, context), @@ -580,7 +581,7 @@ private static IMarshallingGenerator CreateNativeCollectionMarshaller( if (isBlittable) { - marshallingStrategy = new ContiguousBlittableElementCollectionMarshalling(marshallingStrategy, collectionInfo.ElementType.AsTypeSyntax()); + marshallingStrategy = new ContiguousBlittableElementCollectionMarshalling(marshallingStrategy, collectionInfo.ElementType.Syntax); } else { @@ -605,7 +606,7 @@ private static IMarshallingGenerator CreateNativeCollectionMarshaller( numElementsExpression, SizeOfExpression(elementType)); - if (collectionInfo.UseDefaultMarshalling && info.ManagedType is IArrayTypeSymbol { IsSZArray: true }) + if (collectionInfo.UseDefaultMarshalling && info.ManagedType is SzArrayType) { return new ArrayMarshaller( new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: true), @@ -616,7 +617,7 @@ private static IMarshallingGenerator CreateNativeCollectionMarshaller( IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); - if ((collectionInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0) + if ((collectionInfo.MarshallingFeatures & CustomMarshallingFeatures.ManagedTypePinning) != 0) { return new PinnableManagedValueMarshaller(marshallingGenerator); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index 441d2ce758b7f7..e3b32f94a5c6c4 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -83,9 +83,9 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } var safeHandleCreationExpression = ((SafeHandleMarshallingInfo)info.MarshallingAttributeInfo).AccessibleDefaultConstructor - ? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.AsTypeSyntax(), ArgumentList(), initializer: null) + ? (ExpressionSyntax)ObjectCreationExpression(info.ManagedType.Syntax, ArgumentList(), initializer: null) : CastExpression( - info.ManagedType.AsTypeSyntax(), + info.ManagedType.Syntax, InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, @@ -97,7 +97,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont new []{ Argument( TypeOfExpression( - info.ManagedType.AsTypeSyntax())), + info.ManagedType.Syntax)), Argument( LiteralExpression( SyntaxKind.TrueLiteralExpression)) @@ -121,7 +121,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont // leak the handle if we failed to create the handle. yield return LocalDeclarationStatement( VariableDeclaration( - info.ManagedType.AsTypeSyntax(), + info.ManagedType.Syntax, SingletonSeparatedList( VariableDeclarator(newHandleObjectIdentifier) .WithInitializer(EqualsValueClause(safeHandleCreationExpression))))); diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs index d69f2ea1a5855a..96988216005cc6 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/MarshallingAttributeInfo.cs @@ -67,14 +67,15 @@ internal sealed record MarshalAsInfo( internal sealed record BlittableTypeAttributeInfo : MarshallingInfo; [Flags] - internal enum SupportedMarshallingMethods + internal enum CustomMarshallingFeatures { None = 0, ManagedToNative = 0x1, NativeToManaged = 0x2, ManagedToNativeStackalloc = 0x4, - Pinning = 0x8, - All = -1 + ManagedTypePinning = 0x8, + NativeTypePinning = 0x10, + FreeNativeResources = 0x20, } internal abstract record CountInfo; @@ -106,10 +107,9 @@ internal sealed record SizeAndParamIndexInfo(int ConstSize, TypePositionInfo? Pa /// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute /// internal record NativeMarshallingAttributeInfo( - ITypeSymbol NativeMarshallingType, - ITypeSymbol? ValuePropertyType, - SupportedMarshallingMethods MarshallingMethods, - bool NativeTypePinnable, + ManagedTypeInfo NativeMarshallingType, + ManagedTypeInfo? ValuePropertyType, + CustomMarshallingFeatures MarshallingFeatures, bool UseDefaultMarshalling) : MarshallingInfo; /// @@ -122,24 +122,22 @@ internal sealed record GeneratedNativeMarshallingAttributeInfo( /// /// The type of the element is a SafeHandle-derived type with no marshalling attributes. /// - internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo; + internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor, bool IsAbstract) : MarshallingInfo; /// /// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute /// with a contiguous collection marshaller internal sealed record NativeContiguousCollectionMarshallingInfo( - ITypeSymbol NativeMarshallingType, - ITypeSymbol? ValuePropertyType, - SupportedMarshallingMethods MarshallingMethods, - bool NativeTypePinnable, + ManagedTypeInfo NativeMarshallingType, + ManagedTypeInfo? ValuePropertyType, + CustomMarshallingFeatures MarshallingFeatures, bool UseDefaultMarshalling, CountInfo ElementCountInfo, - ITypeSymbol ElementType, + ManagedTypeInfo ElementType, MarshallingInfo ElementMarshallingInfo) : NativeMarshallingAttributeInfo( NativeMarshallingType, ValuePropertyType, - MarshallingMethods, - NativeTypePinnable, + MarshallingFeatures, UseDefaultMarshalling ); @@ -407,8 +405,8 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashSet inspectedElements, ref int maxIndirectionLevelUsed) { - SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; + CustomMarshallingFeatures features = CustomMarshallingFeatures.None; if (!isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) { - methods |= SupportedMarshallingMethods.Pinning; + features |= CustomMarshallingFeatures.ManagedTypePinning; } ITypeSymbol spanOfByte = _compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(_compilation.GetSpecialType(SpecialType.System_Byte)); @@ -611,12 +610,12 @@ MarshallingInfo CreateNativeMarshallingInfo( { if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, marshallingVariant) && (valueProperty is null or { GetMethod: not null })) { - methods |= SupportedMarshallingMethods.ManagedToNative; + features |= CustomMarshallingFeatures.ManagedToNative; } else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, marshallingVariant) && (valueProperty is null or { GetMethod: not null })) { - methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc; + features |= CustomMarshallingFeatures.ManagedToNativeStackalloc; } else if (ctor.Parameters.Length == 1 && ctor.Parameters[0].Type.SpecialType == SpecialType.System_Int32) { @@ -631,10 +630,10 @@ MarshallingInfo CreateNativeMarshallingInfo( && ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type) && (valueProperty is null or { SetMethod: not null })) { - methods |= SupportedMarshallingMethods.NativeToManaged; + features |= CustomMarshallingFeatures.NativeToManaged; } - if (methods == SupportedMarshallingMethods.None) + if (features == CustomMarshallingFeatures.None) { _diagnostics.ReportInvalidMarshallingAttributeInfo( attrData, @@ -645,6 +644,16 @@ MarshallingInfo CreateNativeMarshallingInfo( return NoMarshallingInfo.Instance; } + if (ManualTypeMarshallingHelper.HasFreeNativeMethod(nativeType)) + { + features |= CustomMarshallingFeatures.FreeNativeResources; + } + + if (ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null) + { + features |= CustomMarshallingFeatures.NativeTypePinning; + } + if (isContiguousCollectionMarshaller) { if (!ManualTypeMarshallingHelper.HasNativeValueStorageProperty(nativeType, spanOfByte)) @@ -660,21 +669,19 @@ MarshallingInfo CreateNativeMarshallingInfo( } return new NativeContiguousCollectionMarshallingInfo( - nativeType, - valueProperty?.Type, - methods, - NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType), + valueProperty is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valueProperty.Type) : null, + features, UseDefaultMarshalling: !isMarshalUsingAttribute, parsedCountInfo, - elementType, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType), GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); } return new NativeMarshallingAttributeInfo( - nativeType, - valueProperty?.Type, - methods, - NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, + ManagedTypeInfo.CreateTypeInfoForTypeSymbol(nativeType), + valueProperty is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valueProperty.Type) : null, + features, UseDefaultMarshalling: !isMarshalUsingAttribute); } @@ -705,7 +712,7 @@ bool TryCreateTypeBasedMarshallingInfo( } } } - marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor); + marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor, type.IsAbstract); return true; } @@ -729,14 +736,15 @@ bool TryCreateTypeBasedMarshallingInfo( return false; } + ITypeSymbol? valuePropertyType = ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type; + marshallingInfo = new NativeContiguousCollectionMarshallingInfo( - NativeMarshallingType: arrayMarshaller, - ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, - MarshallingMethods: ~SupportedMarshallingMethods.Pinning, - NativeTypePinnable: true, + NativeMarshallingType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(arrayMarshaller), + ValuePropertyType: valuePropertyType is not null ? ManagedTypeInfo.CreateTypeInfoForTypeSymbol(valuePropertyType) : null, + MarshallingFeatures: ~CustomMarshallingFeatures.ManagedTypePinning, UseDefaultMarshalling: true, ElementCountInfo: parsedCountInfo, - ElementType: elementType, + ElementType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(elementType), ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); return true; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs index fe488c608812e6..1fcacdb07a720e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeContext.cs @@ -61,7 +61,7 @@ public enum Stage GuaranteedUnmarshal } - public Stage CurrentStage { get; protected set; } = Stage.Invalid; + public Stage CurrentStage { get; set; } = Stage.Invalid; /// /// The stub emits code that runs in a single stack frame and the frame spans over the native context. @@ -88,7 +88,7 @@ public enum Stage /// public StubCodeContext? ParentContext { get; protected set; } - protected const string GeneratedNativeIdentifierSuffix = "_gen_native"; + public const string GeneratedNativeIdentifierSuffix = "_gen_native"; /// /// Get managed and native instance identifiers for the diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs index 5679a0235094fd..2732414adc42b2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; @@ -8,11 +9,14 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.Interop.StubCodeContext; namespace Microsoft.Interop { internal sealed class StubCodeGenerator : StubCodeContext { + private record struct BoundGenerator(TypePositionInfo TypeInfo, IMarshallingGenerator Generator); + public override bool SingleFrameSpansNativeContext => true; public override bool AdditionalTemporaryStateLivesAcrossStages => true; @@ -26,7 +30,7 @@ internal sealed class StubCodeGenerator : StubCodeContext /// Identifier for native return value /// /// Same as the managed identifier by default - public string ReturnNativeIdentifier { get; private set; } = ReturnIdentifier; + public string ReturnNativeIdentifier { get; } = ReturnIdentifier; private const string InvokeReturnIdentifier = "__invokeRetVal"; private const string LastErrorIdentifier = "__lastError"; @@ -35,40 +39,62 @@ internal sealed class StubCodeGenerator : StubCodeContext // Error code representing success. This maps to S_OK for Windows HRESULT semantics and 0 for POSIX errno semantics. private const int SuccessErrorCode = 0; - private readonly GeneratorDiagnostics diagnostics; private readonly AnalyzerConfigOptions options; - private readonly IMethodSymbol stubMethod; - private readonly DllImportStub.GeneratedDllImportData dllImportData; - private readonly IEnumerable paramsTypeInfo; - private readonly List<(TypePositionInfo TypeInfo, IMarshallingGenerator Generator)> paramMarshallers; - private readonly (TypePositionInfo TypeInfo, IMarshallingGenerator Generator) retMarshaller; - private readonly List<(TypePositionInfo TypeInfo, IMarshallingGenerator Generator)> sortedMarshallers; + private readonly GeneratedDllImportData dllImportData; + private readonly List paramMarshallers; + private readonly BoundGenerator retMarshaller; + private readonly List sortedMarshallers; + private readonly bool stubReturnsVoid; public StubCodeGenerator( - IMethodSymbol stubMethod, - DllImportStub.GeneratedDllImportData dllImportData, - IEnumerable paramsTypeInfo, - TypePositionInfo retTypeInfo, - GeneratorDiagnostics generatorDiagnostics, - AnalyzerConfigOptions options) + GeneratedDllImportData dllImportData, + IEnumerable argTypes, + AnalyzerConfigOptions options, + Action marshallingNotSupportedCallback) { - Debug.Assert(retTypeInfo.IsNativeReturnPosition); - - this.stubMethod = stubMethod; this.dllImportData = dllImportData; - this.paramsTypeInfo = paramsTypeInfo.ToList(); - this.diagnostics = generatorDiagnostics; this.options = options; - // Get marshallers for parameters - this.paramMarshallers = paramsTypeInfo.Select(p => CreateGenerator(p)).ToList(); + List allMarshallers = new(); + List paramMarshallers = new(); + bool foundNativeRetMarshaller = false; + bool foundManagedRetMarshaller = false; + BoundGenerator nativeRetMarshaller = new(new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance), new Forwarder()); + BoundGenerator managedRetMarshaller = new(new TypePositionInfo(SpecialTypeInfo.Void, NoMarshallingInfo.Instance), new Forwarder()); + + foreach (var argType in argTypes) + { + BoundGenerator generator = CreateGenerator(argType); + allMarshallers.Add(generator); + if (argType.IsManagedReturnPosition) + { + Debug.Assert(!foundManagedRetMarshaller); + managedRetMarshaller = generator; + foundManagedRetMarshaller = true; + } + if (argType.IsNativeReturnPosition) + { + Debug.Assert(!foundNativeRetMarshaller); + nativeRetMarshaller = generator; + foundNativeRetMarshaller = true; + } + if (!argType.IsManagedReturnPosition && !argType.IsNativeReturnPosition) + { + paramMarshallers.Add(generator); + } + } - // Get marshaller for return - this.retMarshaller = CreateGenerator(retTypeInfo); + this.stubReturnsVoid = managedRetMarshaller.TypeInfo.ManagedType == SpecialTypeInfo.Void; + if (!managedRetMarshaller.TypeInfo.IsNativeReturnPosition && !this.stubReturnsVoid) + { + // If the managed ret marshaller isn't the native ret marshaller, then the managed ret marshaller + // is a parameter. + paramMarshallers.Add(managedRetMarshaller); + } - List<(TypePositionInfo TypeInfo, IMarshallingGenerator Generator)> allMarshallers = new(this.paramMarshallers); - allMarshallers.Add(retMarshaller); + this.retMarshaller = nativeRetMarshaller; + this.paramMarshallers = paramMarshallers; // We are doing a topological sort of our marshallers to ensure that each parameter/return value's // dependencies are unmarshalled before their dependents. This comes up in the case of contiguous @@ -98,17 +124,10 @@ public StubCodeGenerator( static m => GetInfoDependencies(m.TypeInfo)) .ToList(); - (TypePositionInfo info, IMarshallingGenerator gen) CreateGenerator(TypePositionInfo p) + if (managedRetMarshaller.Generator.UsesNativeIdentifier(managedRetMarshaller.TypeInfo, this)) { - try - { - return (p, MarshallingGenerators.Create(p, this, options)); - } - catch (MarshallingNotSupportedException e) - { - this.diagnostics.ReportMarshallingNotSupported(this.stubMethod, p, e.NotSupportedDetails); - return (p, MarshallingGenerators.Forwarder); - } + // Update the native identifier for the return value + this.ReturnNativeIdentifier = $"{ReturnIdentifier}{GeneratedNativeIdentifierSuffix}"; } static IEnumerable GetInfoDependencies(TypePositionInfo info) @@ -132,6 +151,19 @@ static int GetInfoIndex(TypePositionInfo info) } return info.ManagedIndex; } + + BoundGenerator CreateGenerator(TypePositionInfo p) + { + try + { + return new BoundGenerator(p, MarshallingGenerators.Create(p, this, options)); + } + catch (MarshallingNotSupportedException e) + { + marshallingNotSupportedCallback(p, e); + return new BoundGenerator(p, MarshallingGenerators.Forwarder); + } + } } public override (string managed, string native) GetIdentifiers(TypePositionInfo info) @@ -164,17 +196,11 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo } } - public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) + public BlockSyntax GenerateBody(string methodName, AttributeListSyntax? forwardedAttributes) { - string dllImportName = stubMethod.Name + "__PInvoke__"; + string dllImportName = methodName + "__PInvoke__"; var setupStatements = new List(); - if (retMarshaller.Generator.UsesNativeIdentifier(retMarshaller.TypeInfo, this)) - { - // Update the native identifier for the return value - ReturnNativeIdentifier = $"{ReturnIdentifier}{GeneratedNativeIdentifierSuffix}"; - } - foreach (var marshaller in paramMarshallers) { TypePositionInfo info = marshaller.TypeInfo; @@ -197,8 +223,7 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) AppendVariableDeclations(setupStatements, info, marshaller.Generator); } - bool invokeReturnsVoid = retMarshaller.TypeInfo.ManagedType.SpecialType == SpecialType.System_Void; - bool stubReturnsVoid = stubMethod.ReturnsVoid; + bool invokeReturnsVoid = retMarshaller.TypeInfo.ManagedType == SpecialTypeInfo.Void; // Stub return is not the same as invoke return if (!stubReturnsVoid && !retMarshaller.TypeInfo.IsManagedReturnPosition) @@ -210,11 +235,6 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) Debug.Assert(paramMarshallers.Any() && paramMarshallers.Last().TypeInfo.IsManagedReturnPosition, "Expected stub return to be the last parameter for the invoke"); (TypePositionInfo stubRetTypeInfo, IMarshallingGenerator stubRetGenerator) = paramMarshallers.Last(); - if (stubRetGenerator.UsesNativeIdentifier(stubRetTypeInfo, this)) - { - // Update the native identifier for the return value - ReturnNativeIdentifier = $"{ReturnIdentifier}{GeneratedNativeIdentifierSuffix}"; - } // Declare variables for stub return value AppendVariableDeclations(setupStatements, stubRetTypeInfo, stubRetGenerator); @@ -303,7 +323,7 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) .WithAttributeLists( SingletonList(AttributeList( - SingletonSeparatedList(CreateDllImportAttributeForTarget(GetTargetDllImportDataFromStubData()))))); + SingletonSeparatedList(CreateDllImportAttributeForTarget(GetTargetDllImportDataFromStubData(methodName)))))); if (retMarshaller.Generator is IAttributedReturnTypeMarshallingGenerator retGenerator) { @@ -313,7 +333,7 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) dllImport = dllImport.AddAttributeLists(returnAttribute.WithTarget(AttributeTargetSpecifier(Identifier("return")))); } } - + if (forwardedAttributes is not null) { dllImport = dllImport.AddAttributeLists(forwardedAttributes); @@ -334,7 +354,6 @@ void GenerateStatementsForStage(Stage stage, List statementsToU if (!invokeReturnsVoid && (stage is Stage.Setup or Stage.Cleanup)) { - // Handle setup and unmarshalling for return var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, this); statementsToUpdate.AddRange(retStatements); } @@ -400,7 +419,6 @@ void GenerateStatementsForInvoke(List statementsToUpdate, Invoc } StatementSyntax invokeStatement; - // Assign to return value if necessary if (invokeReturnsVoid) { @@ -442,7 +460,6 @@ void GenerateStatementsForInvoke(List statementsToUpdate, Invoc invokeStatement = Block(clearLastError, invokeStatement, getLastError); } - // Nest invocation in fixed statements if (fixedStatements.Any()) { @@ -467,13 +484,13 @@ void GenerateStatementsForInvoke(List statementsToUpdate, Invoc private void AppendVariableDeclations(List statementsToUpdate, TypePositionInfo info, IMarshallingGenerator generator) { - var (managed, native) = GetIdentifiers(info); + var (managed, native) = this.GetIdentifiers(info); // Declare variable for return value if (info.IsManagedReturnPosition || info.IsNativeReturnPosition) { statementsToUpdate.Add(MarshallerHelpers.DeclareWithDefault( - info.ManagedType.AsTypeSyntax(), + info.ManagedType.Syntax, managed)); } @@ -486,8 +503,9 @@ private void AppendVariableDeclations(List statementsToUpdate, } } - private static AttributeSyntax CreateDllImportAttributeForTarget(DllImportStub.GeneratedDllImportData targetDllImportData) + private static AttributeSyntax CreateDllImportAttributeForTarget(GeneratedDllImportData targetDllImportData) { + Debug.Assert(targetDllImportData.EntryPoint is not null); var newAttributeArgs = new List { AttributeArgument(LiteralExpression( @@ -496,46 +514,46 @@ private static AttributeSyntax CreateDllImportAttributeForTarget(DllImportStub.G AttributeArgument( NameEquals(nameof(DllImportAttribute.EntryPoint)), null, - CreateStringExpressionSyntax(targetDllImportData.EntryPoint)) + CreateStringExpressionSyntax(targetDllImportData.EntryPoint!)) }; - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.BestFitMapping)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.BestFitMapping)) { var name = NameEquals(nameof(DllImportAttribute.BestFitMapping)); var value = CreateBoolExpressionSyntax(targetDllImportData.BestFitMapping); newAttributeArgs.Add(AttributeArgument(name, null, value)); } - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CallingConvention)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.CallingConvention)) { var name = NameEquals(nameof(DllImportAttribute.CallingConvention)); var value = CreateEnumExpressionSyntax(targetDllImportData.CallingConvention); newAttributeArgs.Add(AttributeArgument(name, null, value)); } - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.CharSet)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.CharSet)) { var name = NameEquals(nameof(DllImportAttribute.CharSet)); var value = CreateEnumExpressionSyntax(targetDllImportData.CharSet); newAttributeArgs.Add(AttributeArgument(name, null, value)); } - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ExactSpelling)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.ExactSpelling)) { var name = NameEquals(nameof(DllImportAttribute.ExactSpelling)); var value = CreateBoolExpressionSyntax(targetDllImportData.ExactSpelling); newAttributeArgs.Add(AttributeArgument(name, null, value)); } - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.PreserveSig)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.PreserveSig)) { var name = NameEquals(nameof(DllImportAttribute.PreserveSig)); var value = CreateBoolExpressionSyntax(targetDllImportData.PreserveSig); newAttributeArgs.Add(AttributeArgument(name, null, value)); } - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.SetLastError)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.SetLastError)) { var name = NameEquals(nameof(DllImportAttribute.SetLastError)); var value = CreateBoolExpressionSyntax(targetDllImportData.SetLastError); newAttributeArgs.Add(AttributeArgument(name, null, value)); } - if (targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.ThrowOnUnmappableChar)) + if (targetDllImportData.IsUserDefined.HasFlag(DllImportMember.ThrowOnUnmappableChar)) { var name = NameEquals(nameof(DllImportAttribute.ThrowOnUnmappableChar)); var value = CreateBoolExpressionSyntax(targetDllImportData.ThrowOnUnmappableChar); @@ -571,31 +589,22 @@ static ExpressionSyntax CreateEnumExpressionSyntax(T value) where T : Enum } } - DllImportStub.GeneratedDllImportData GetTargetDllImportDataFromStubData() + GeneratedDllImportData GetTargetDllImportDataFromStubData(string methodName) { - DllImportStub.DllImportMember membersToForward = DllImportStub.DllImportMember.All + DllImportMember membersToForward = DllImportMember.All // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.preservesig // If PreserveSig=false (default is true), the P/Invoke stub checks/converts a returned HRESULT to an exception. - & ~DllImportStub.DllImportMember.PreserveSig + & ~DllImportMember.PreserveSig // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.dllimportattribute.setlasterror // If SetLastError=true (default is false), the P/Invoke stub gets/caches the last error after invoking the native function. - & ~DllImportStub.DllImportMember.SetLastError; + & ~DllImportMember.SetLastError; if (options.GenerateForwarders()) { - membersToForward = DllImportStub.DllImportMember.All; + membersToForward = DllImportMember.All; } - var targetDllImportData = new DllImportStub.GeneratedDllImportData + var targetDllImportData = dllImportData with { - CharSet = dllImportData.CharSet, - BestFitMapping = dllImportData.BestFitMapping, - CallingConvention = dllImportData.CallingConvention, - EntryPoint = dllImportData.EntryPoint, - ModuleName = dllImportData.ModuleName, - ExactSpelling = dllImportData.ExactSpelling, - SetLastError = dllImportData.SetLastError, - PreserveSig = dllImportData.PreserveSig, - ThrowOnUnmappableChar = dllImportData.ThrowOnUnmappableChar, IsUserDefined = dllImportData.IsUserDefined & membersToForward }; @@ -604,9 +613,9 @@ DllImportStub.GeneratedDllImportData GetTargetDllImportDataFromStubData() // // N.B. The export discovery logic is identical regardless of where // the name is defined (i.e. method name vs EntryPoint property). - if (!targetDllImportData.IsUserDefined.HasFlag(DllImportStub.DllImportMember.EntryPoint)) + if (!targetDllImportData.IsUserDefined.HasFlag(DllImportMember.EntryPoint)) { - targetDllImportData.EntryPoint = stubMethod.Name; + targetDllImportData = targetDllImportData with { EntryPoint = methodName }; } return targetDllImportData; diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs index 0f011d9f0846a0..2cbbcb87beccb2 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypePositionInfo.cs @@ -40,27 +40,15 @@ internal enum ByValueContentsMarshalKind /// /// Positional type information involved in unmanaged/managed scenarios. /// - internal sealed record TypePositionInfo + internal sealed record TypePositionInfo(ManagedTypeInfo ManagedType, MarshallingInfo MarshallingAttributeInfo) { public const int UnsetIndex = int.MinValue; public const int ReturnIndex = UnsetIndex + 1; -// We don't need the warnings around not setting the various -// non-nullable fields/properties on this type in the constructor -// since we always use a property initializer. -#pragma warning disable 8618 - private TypePositionInfo() - { - this.ManagedIndex = UnsetIndex; - this.NativeIndex = UnsetIndex; - } -#pragma warning restore - - public string InstanceIdentifier { get; init; } - public ITypeSymbol ManagedType { get; init; } + public string InstanceIdentifier { get; init; } = string.Empty; - public RefKind RefKind { get; init; } - public SyntaxKind RefKindSyntax { get; init; } + public RefKind RefKind { get; init; } = RefKind.None; + public SyntaxKind RefKindSyntax { get; init; } = SyntaxKind.None; public bool IsByRef => RefKind != RefKind.None; @@ -69,40 +57,22 @@ private TypePositionInfo() public bool IsManagedReturnPosition { get => this.ManagedIndex == ReturnIndex; } public bool IsNativeReturnPosition { get => this.NativeIndex == ReturnIndex; } - public int ManagedIndex { get; init; } - public int NativeIndex { get; init; } - - public MarshallingInfo MarshallingAttributeInfo { get; init; } + public int ManagedIndex { get; init; } = UnsetIndex; + public int NativeIndex { get; init; } = UnsetIndex; public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingInfo marshallingInfo, Compilation compilation) { - var typeInfo = new TypePositionInfo() + var typeInfo = new TypePositionInfo(ManagedTypeInfo.CreateTypeInfoForTypeSymbol(paramSymbol.Type), marshallingInfo) { - ManagedType = paramSymbol.Type, InstanceIdentifier = ParseToken(paramSymbol.Name).IsReservedKeyword() ? $"@{paramSymbol.Name}" : paramSymbol.Name, RefKind = paramSymbol.RefKind, RefKindSyntax = RefKindToSyntax(paramSymbol.RefKind), - MarshallingAttributeInfo = marshallingInfo, ByValueContentsMarshalKind = GetByValueContentsMarshalKind(paramSymbol.GetAttributes(), compilation) }; return typeInfo; } - public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo marshallingInfo, string identifier = "") - { - var typeInfo = new TypePositionInfo() - { - ManagedType = type, - InstanceIdentifier = identifier, - RefKind = RefKind.None, - RefKindSyntax = SyntaxKind.None, - MarshallingAttributeInfo = marshallingInfo - }; - - return typeInfo; - } - private static ByValueContentsMarshalKind GetByValueContentsMarshalKind(IEnumerable attributes, Compilation compilation) { INamedTypeSymbol outAttributeType = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_OutAttribute)!; diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs index a37aa5708efadd..e9b3b8261a95f7 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/TypeSymbolExtensions.cs @@ -163,9 +163,9 @@ public static TypeSyntax AsTypeSyntax(this ITypeSymbol type) return SyntaxFactory.ParseTypeName(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); } - public static bool IsIntegralType(this ITypeSymbol type) + public static bool IsIntegralType(this SpecialType type) { - return type.SpecialType switch + return type switch { SpecialType.System_SByte or SpecialType.System_Byte diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs new file mode 100644 index 00000000000000..f33ae8b9564fd1 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/UnreachableException.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Interop +{ + /// + /// An exception that should be thrown on code-paths that are unreachable. + /// + internal class UnreachableException : Exception + { + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj index 7af1a40c40bcd1..e19d9b4e5d9bf9 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj +++ b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/Ancillary.Interop.csproj @@ -3,7 +3,6 @@ Microsoft.Interop.Ancillary net6.0 - 8.0 System.Runtime.InteropServices enable true diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs new file mode 100644 index 00000000000000..e37e6fdd7ab439 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/IncrementalGenerationTests.cs @@ -0,0 +1,203 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; +using static Microsoft.Interop.DllImportGenerator; + +namespace DllImportGenerator.UnitTests +{ + public class IncrementalGenerationTests + { + public const string RequiresIncrementalSyntaxTreeModifySupport = "The GeneratorDriver treats all SyntaxTree replace operations on a Compilation as an Add/Remove operation instead of a Modify operation" + + ", so all cached results based on that input are thrown out. As a result, we cannot validate that unrelated changes within the same SyntaxTree do not cause regeneration."; + + [Fact] + public async Task AddingNewUnrelatedType_DoesNotRegenerateSource() + { + string source = CodeSnippets.BasicParametersAndModifiers(); + + Compilation comp1 = await TestUtils.CreateCompilation(source); + + Microsoft.Interop.DllImportGenerator generator = new(); + GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new IIncrementalGenerator[] { generator }); + + driver = driver.RunGenerators(comp1); + + generator.IncrementalTracker = new IncrementalityTracker(); + + Compilation comp2 = comp1.AddSyntaxTrees(CSharpSyntaxTree.ParseText("struct Foo {}", new CSharpParseOptions(LanguageVersion.Preview))); + driver.RunGenerators(comp2); + + Assert.Collection(generator.IncrementalTracker.ExecutedSteps, + step => + { + Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step); + }); + } + + [Fact(Skip = RequiresIncrementalSyntaxTreeModifySupport)] + public async Task AppendingUnrelatedSource_DoesNotRegenerateSource() + { + string source = $"namespace NS{{{CodeSnippets.BasicParametersAndModifiers()}}}"; + + SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)); + + Compilation comp1 = await TestUtils.CreateCompilation(new[] { syntaxTree }); + + Microsoft.Interop.DllImportGenerator generator = new(); + GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator }); + + driver = driver.RunGenerators(comp1); + + generator.IncrementalTracker = new IncrementalityTracker(); + + SyntaxTree newTree = syntaxTree.WithRootAndOptions(syntaxTree.GetCompilationUnitRoot().AddMembers(SyntaxFactory.ParseMemberDeclaration("struct Foo {}")!), syntaxTree.Options); + + Compilation comp2 = comp1.ReplaceSyntaxTree(comp1.SyntaxTrees.First(), newTree); + driver.RunGenerators(comp2); + + Assert.Collection(generator.IncrementalTracker.ExecutedSteps, + step => + { + Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step); + }); + } + + [Fact] + public async Task AddingFileWithNewGeneratedDllImport_DoesNotRegenerateOriginalMethod() + { + string source = CodeSnippets.BasicParametersAndModifiers(); + + Compilation comp1 = await TestUtils.CreateCompilation(source); + + Microsoft.Interop.DllImportGenerator generator = new(); + GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator }); + + driver = driver.RunGenerators(comp1); + + generator.IncrementalTracker = new IncrementalityTracker(); + + Compilation comp2 = comp1.AddSyntaxTrees(CSharpSyntaxTree.ParseText(CodeSnippets.BasicParametersAndModifiers(), new CSharpParseOptions(LanguageVersion.Preview))); + driver.RunGenerators(comp2); + + Assert.Equal(2, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.CalculateStubInformation)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.GenerateSingleStub)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.NormalizeWhitespace)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.ConcatenateStubs)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.OutputSourceFile)); + } + + [Fact] + public async Task ReplacingFileWithNewGeneratedDllImport_DoesNotRegenerateStubsInOtherFiles() + { + string source = CodeSnippets.BasicParametersAndModifiers(); + + Compilation comp1 = await TestUtils.CreateCompilation(new string[] { CodeSnippets.BasicParametersAndModifiers(), CodeSnippets.BasicParametersAndModifiers() }); + + Microsoft.Interop.DllImportGenerator generator = new(); + GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator }); + + driver = driver.RunGenerators(comp1); + + generator.IncrementalTracker = new IncrementalityTracker(); + + Compilation comp2 = comp1.ReplaceSyntaxTree(comp1.SyntaxTrees.First(), CSharpSyntaxTree.ParseText(CodeSnippets.BasicParametersAndModifiers(), new CSharpParseOptions(LanguageVersion.Preview))); + driver.RunGenerators(comp2); + + Assert.Equal(2, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.CalculateStubInformation)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.GenerateSingleStub)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.NormalizeWhitespace)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.ConcatenateStubs)); + Assert.Equal(1, generator.IncrementalTracker.ExecutedSteps.Count(s => s.Step == IncrementalityTracker.StepName.OutputSourceFile)); + } + + [Fact] + public async Task ChangingMarshallingStrategy_RegeneratesStub() + { + string stubSource = CodeSnippets.BasicParametersAndModifiers("CustomType"); + + string customTypeImpl1 = "struct CustomType { System.IntPtr handle; }"; + + string customTypeImpl2 = "class CustomType : Microsoft.Win32.SafeHandles.SafeHandleZeroOrMinusOneIsInvalid { public CustomType():base(true){} protected override bool ReleaseHandle(){return true;} }"; + + + Compilation comp1 = await TestUtils.CreateCompilation(stubSource); + + SyntaxTree customTypeImpl1Tree = CSharpSyntaxTree.ParseText(customTypeImpl1, new CSharpParseOptions(LanguageVersion.Preview)); + comp1 = comp1.AddSyntaxTrees(customTypeImpl1Tree); + + Microsoft.Interop.DllImportGenerator generator = new(); + GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator }); + + driver = driver.RunGenerators(comp1); + + generator.IncrementalTracker = new IncrementalityTracker(); + + Compilation comp2 = comp1.ReplaceSyntaxTree(customTypeImpl1Tree, CSharpSyntaxTree.ParseText(customTypeImpl2, new CSharpParseOptions(LanguageVersion.Preview))); + driver.RunGenerators(comp2); + + Assert.Collection(generator.IncrementalTracker.ExecutedSteps, + step => + { + Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step); + }, + step => + { + Assert.Equal(IncrementalityTracker.StepName.GenerateSingleStub, step.Step); + }, + step => + { + Assert.Equal(IncrementalityTracker.StepName.NormalizeWhitespace, step.Step); + }, + step => + { + Assert.Equal(IncrementalityTracker.StepName.ConcatenateStubs, step.Step); + }, + step => + { + Assert.Equal(IncrementalityTracker.StepName.OutputSourceFile, step.Step); + }); + } + + [Fact(Skip = RequiresIncrementalSyntaxTreeModifySupport)] + public async Task ChangingMarshallingAttributes_SameStrategy_DoesNotRegenerate() + { + string source = CodeSnippets.BasicParametersAndModifiers(); + + SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)); + + Compilation comp1 = await TestUtils.CreateCompilation(new[] { syntaxTree }); + + Microsoft.Interop.DllImportGenerator generator = new(); + GeneratorDriver driver = TestUtils.CreateDriver(comp1, null, new[] { generator }); + + driver = driver.RunGenerators(comp1); + + generator.IncrementalTracker = new IncrementalityTracker(); + + SyntaxTree newTree = syntaxTree.WithRootAndOptions( + syntaxTree.GetCompilationUnitRoot().AddMembers( + SyntaxFactory.ParseMemberDeclaration( + CodeSnippets.MarshalAsParametersAndModifiers(System.Runtime.InteropServices.UnmanagedType.Bool))!), + syntaxTree.Options); + + Compilation comp2 = comp1.ReplaceSyntaxTree(comp1.SyntaxTrees.First(), newTree); + driver.RunGenerators(comp2); + + Assert.Collection(generator.IncrementalTracker.ExecutedSteps, + step => + { + Assert.Equal(IncrementalityTracker.StepName.CalculateStubInformation, step.Step); + }, + step => + { + Assert.Equal(IncrementalityTracker.StepName.GenerateSingleStub, step.Step); + }); + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs index fd49c82ccf9142..db7b0f3feccf8c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.UnitTests/TestUtils.cs @@ -45,14 +45,26 @@ public static void AssertPreSourceGeneratorCompilation(Compilation comp) /// Output type /// Whether or not use of the unsafe keyword should be allowed /// The resulting compilation - public static async Task CreateCompilation(string source, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable? preprocessorSymbols = null) + public static Task CreateCompilation(string source, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable? preprocessorSymbols = null) { - var (mdRefs, ancillary) = GetReferenceAssemblies(); + return CreateCompilation(new[] { source }, outputKind, allowUnsafe, preprocessorSymbols); + } - return CSharpCompilation.Create("compilation", - new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols)) }, - (await mdRefs.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)).Add(ancillary), - new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe)); + /// + /// Create a compilation given sources + /// + /// Sources to compile + /// Output type + /// Whether or not use of the unsafe keyword should be allowed + /// The resulting compilation + public static Task CreateCompilation(string[] sources, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable? preprocessorSymbols = null) + { + return CreateCompilation( + sources.Select(source => + CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols))).ToArray(), + outputKind, + allowUnsafe, + preprocessorSymbols); } /// @@ -62,13 +74,12 @@ public static async Task CreateCompilation(string source, OutputKin /// Output type /// Whether or not use of the unsafe keyword should be allowed /// The resulting compilation - public static async Task CreateCompilation(string[] sources, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable? preprocessorSymbols = null) + public static async Task CreateCompilation(SyntaxTree[] sources, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable? preprocessorSymbols = null) { var (mdRefs, ancillary) = GetReferenceAssemblies(); return CSharpCompilation.Create("compilation", - sources.Select(source => - CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols))).ToArray(), + sources, (await mdRefs.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)).Add(ancillary), new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe)); } @@ -81,10 +92,23 @@ public static async Task CreateCompilation(string[] sources, Output /// Output type /// Whether or not use of the unsafe keyword should be allowed /// The resulting compilation - public static async Task CreateCompilationWithReferenceAssemblies(string source, ReferenceAssemblies referenceAssemblies, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true) + public static Task CreateCompilationWithReferenceAssemblies(string source, ReferenceAssemblies referenceAssemblies, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true) + { + return CreateCompilationWithReferenceAssemblies(new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)) }, referenceAssemblies, outputKind, allowUnsafe); + } + + /// + /// Create a compilation given source and reference assemblies + /// + /// Source to compile + /// Reference assemblies to include + /// Output type + /// Whether or not use of the unsafe keyword should be allowed + /// The resulting compilation + public static async Task CreateCompilationWithReferenceAssemblies(SyntaxTree[] sources, ReferenceAssemblies referenceAssemblies, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true) { return CSharpCompilation.Create("compilation", - new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)) }, + sources, (await referenceAssemblies.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)), new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe)); } @@ -96,7 +120,7 @@ public static (ReferenceAssemblies, MetadataReference) GetReferenceAssemblies() "net6.0", new PackageIdentity( "Microsoft.NETCore.App.Ref", - "6.0.0-preview.6.21317.4"), + "6.0.0-preview.7.21377.19"), Path.Combine("ref", "net6.0")) .WithNuGetConfigFilePath(Path.Combine(Path.GetDirectoryName(typeof(TestUtils).Assembly.Location)!, "NuGet.config")); @@ -114,7 +138,7 @@ public static (ReferenceAssemblies, MetadataReference) GetReferenceAssemblies() /// Resulting diagnostics /// Source generator instances /// The resulting compilation - public static Compilation RunGenerators(Compilation comp, out ImmutableArray diagnostics, params ISourceGenerator[] generators) + public static Compilation RunGenerators(Compilation comp, out ImmutableArray diagnostics, params IIncrementalGenerator[] generators) { CreateDriver(comp, null, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics); return d; @@ -127,15 +151,15 @@ public static Compilation RunGenerators(Compilation comp, out ImmutableArrayResulting diagnostics /// Source generator instances /// The resulting compilation - public static Compilation RunGenerators(Compilation comp, AnalyzerConfigOptionsProvider options, out ImmutableArray diagnostics, params ISourceGenerator[] generators) + public static Compilation RunGenerators(Compilation comp, AnalyzerConfigOptionsProvider options, out ImmutableArray diagnostics, params IIncrementalGenerator[] generators) { CreateDriver(comp, options, generators).RunGeneratorsAndUpdateCompilation(comp, out var d, out diagnostics); return d; } - private static GeneratorDriver CreateDriver(Compilation c, AnalyzerConfigOptionsProvider? options, ISourceGenerator[] generators) + public static GeneratorDriver CreateDriver(Compilation c, AnalyzerConfigOptionsProvider? options, IIncrementalGenerator[] generators) => CSharpGeneratorDriver.Create( - ImmutableArray.Create(generators), + ImmutableArray.Create(generators.Select(gen => gen.AsSourceGenerator()).ToArray()), parseOptions: (CSharpParseOptions)c.SyntaxTrees.First().Options, optionsProvider: options); }