From 71796dda7a6c81c20fe7815a6894892b30475e06 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 26 May 2021 16:24:22 -0700 Subject: [PATCH 01/30] Add marshaller for blittable collections. --- ...onBlittableElementsMarshallingGenerator.cs | 105 ++++++++++++++++++ .../Marshalling/CustomNativeTypeMarshaller.cs | 34 +++++- 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs new file mode 100644 index 000000000000..88c9cf311418 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -0,0 +1,105 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + class ContiguousCollectionBlittableElementsMarshallingGenerator : CustomNativeTypeMarshaller + { + private readonly ITypeSymbol elementType; + private readonly ExpressionSyntax numElementsExpression; + + public ContiguousCollectionBlittableElementsMarshallingGenerator( + NativeMarshallingAttributeInfo marshallingInfo, + ITypeSymbol elementType, + ExpressionSyntax numElementsExpression) + :base(marshallingInfo) + { + this.elementType = elementType; + this.numElementsExpression = numElementsExpression; + } + + protected override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + yield return Argument(SizeOfExpression(elementType.AsTypeSyntax())); + } + + protected override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + string marshalerIdentifier = GetMarshallerIdentifier(info, context); + // .ManagedValues.CopyTo(MemoryMarshal.Cast>( GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + string marshalerIdentifier = GetMarshallerIdentifier(info, context); + // MemoryMarshal.Cast>(.ManagedValues); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("MemoryMarshal"), + GenericName( + Identifier("Cast")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new [] + { + PredefinedType(Token(SyntaxKind.ByteKeyword)), + elementType.AsTypeSyntax() + }))))) + .AddArgumentListArguments( + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshalerIdentifier), + IdentifierName("NativeValueStorage")))), + IdentifierName("CopyTo"))) + .AddArgumentListArguments( + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshalerIdentifier), + IdentifierName("ManagedValues"))))); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs index c2f70b702b8f..456ce82f6022 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs @@ -38,6 +38,12 @@ public CustomNativeTypeMarshaller(GeneratedNativeMarshallingAttributeInfo marsha _marshalerTypePinnable = false; } + protected string GetMarshallerIdentifier(TypePositionInfo info, StubCodeContext context) + { + var (_, nativeIdentifier) = context.GetIdentifiers(info); + return _useValueProperty ? nativeIdentifier + MarshalerLocalSuffix : nativeIdentifier; + } + public TypeSyntax AsNativeType(TypePositionInfo info) { return _nativeTypeSyntax; @@ -74,7 +80,7 @@ public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); - string marshalerIdentifier = _useValueProperty ? nativeIdentifier + MarshalerLocalSuffix : nativeIdentifier; + string marshalerIdentifier = GetMarshallerIdentifier(info, context); if (!info.IsManagedReturnPosition && !info.IsByRef && context.PinningSupported @@ -160,6 +166,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont _nativeLocalTypeSyntax, IdentifierName(ManualTypeMarshallingHelper.StackBufferSizeFieldName))) }))))); + arguments.AddRange(GenerateAdditionalNativeTypeConstructorArguments(info, context)); } // = new <_nativeLocalType>(); @@ -170,6 +177,11 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont ObjectCreationExpression(_nativeLocalTypeSyntax) .WithArgumentList(ArgumentList(SeparatedList(arguments))))); + foreach (var statement in GenerateIntermediateMarshallingStatements(info, context)) + { + yield return statement; + } + bool skipValueProperty = _marshalerTypePinnable && (!info.IsByRef || info.RefKind == RefKind.In); if (_useValueProperty && !skipValueProperty) @@ -219,6 +231,11 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont IdentifierName(nativeIdentifier))); } + foreach (var statement in GenerateIntermediateUnmarshallingStatements(info, context)) + { + yield return statement; + } + // = .ToManaged(); yield return ExpressionStatement( AssignmentExpression( @@ -247,6 +264,21 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } } + protected virtual IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + protected virtual IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + protected virtual IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) { if (info.IsManagedReturnPosition || info.IsByRef && info.RefKind != RefKind.In) From f4fe65659601ddf01803262def50595b57498cc3 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 27 May 2021 10:43:35 -0700 Subject: [PATCH 02/30] Implement contiguous non-blittable collection marshalling. --- ...CollectionElementMarshallingCodeContext.cs | 73 +++++++++++ ...onBlittableElementsMarshallingGenerator.cs | 12 ++ ...onBlittableElementsMarshallingGenerator.cs | 113 ++++++++++++++++++ .../Marshalling/CustomNativeTypeMarshaller.cs | 12 +- .../Marshalling/MarshallerHelpers.cs | 11 ++ 5 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs diff --git a/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs b/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs new file mode 100644 index 000000000000..e203391abf3d --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs @@ -0,0 +1,73 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + internal sealed class ContiguousCollectionElementMarshallingCodeContext : StubCodeContext + { + private readonly string indexerIdentifier; + private readonly string nativeSpanIdentifier; + private readonly StubCodeContext parentContext; + + public override bool PinningSupported => false; + + public override bool StackSpaceUsable => false; + + /// + /// Additional variables other than the {managedIdentifier} and {nativeIdentifier} variables + /// can be added to the stub to track additional state for the marshaller in the stub. + /// + /// + /// Currently, collection scenarios do not support declaring additional temporary variables to support + /// marshalling. This can be accomplished in the future with some additional infrastructure to support + /// declaring additional arrays in the stub to support the temporary state. + /// + public override bool CanUseAdditionalTemporaryState => false; + + /// + /// Create a for marshalling elements of an collection. + /// + /// The current marshalling stage. + /// The indexer in the loop to get the element to marshal from the collection. + /// The identifier of the native value storage cast to the target element type. + /// The parent context. + public ContiguousCollectionElementMarshallingCodeContext( + Stage currentStage, + string indexerIdentifier, + string nativeSpanIdentifier, + StubCodeContext parentContext) + { + CurrentStage = currentStage; + this.indexerIdentifier = indexerIdentifier; + this.nativeSpanIdentifier = nativeSpanIdentifier; + this.parentContext = parentContext; + } + + /// + /// Get managed and native instance identifiers for the + /// + /// Object for which to get identifiers + /// Managed and native identifiers + public override (string managed, string native) GetIdentifiers(TypePositionInfo info) + { + var (managed, _) = parentContext.GetIdentifiers(info); + return ( + $"{managed}{CustomNativeTypeMarshaller.MarshalerLocalSuffix}.ManagedValues[{indexerIdentifier}]", + $"{nativeSpanIdentifier}[{indexerIdentifier}]" + ); + } + + public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index) + { + // We don't have parameters to look at when we're in the middle of marshalling an array. + return null; + } + } +} diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs index 88c9cf311418..be4fc800197f 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -65,6 +65,18 @@ protected override IEnumerable GenerateIntermediateMarshallingS IdentifierName("NativeValueStorage"))))))); } + protected override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + string marshalerIdentifier = GetMarshallerIdentifier(info, context); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshalerIdentifier), + IdentifierName("SetUnmarshalledCollectionLength"))) + .AddArgumentListArguments(Argument(numElementsExpression))); + } + protected override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs new file mode 100644 index 000000000000..e32ea09f196e --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -0,0 +1,113 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + class ContiguousCollectionNonBlittableElementsMarshallingGenerator : CustomNativeTypeMarshaller + { + private const string IndexerIdentifier = "__i"; + private readonly IMarshallingGenerator elementMarshaller; + private readonly TypePositionInfo elementInfo; + private readonly ExpressionSyntax numElementsExpression; + + public ContiguousCollectionNonBlittableElementsMarshallingGenerator( + NativeMarshallingAttributeInfo marshallingInfo, + IMarshallingGenerator elementMarshaller, + TypePositionInfo elementInfo, + ExpressionSyntax numElementsExpression) + :base(marshallingInfo) + { + this.elementMarshaller = elementMarshaller; + this.elementInfo = elementInfo; + this.numElementsExpression = numElementsExpression; + } + + protected override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + yield return Argument(SizeOfExpression(elementMarshaller.AsNativeType(elementInfo))); + } + + private string GetNativeSpanIdentifier(TypePositionInfo info, StubCodeContext context) + { + return context.GetIdentifiers(info).managed + "__nativeSpan"; + } + + private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(string nativeSpanIdentifier) + { + return LocalDeclarationStatement(VariableDeclaration( + GenericName( + Identifier(TypeNames.System_Span), + TypeArgumentList( + SingletonSeparatedList(PredefinedType(Token(SyntaxKind.ByteKeyword)))) + ), + SingletonSeparatedList( + VariableDeclarator(Identifier(nativeSpanIdentifier)) + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("MemoryMarshal"), + GenericName( + Identifier("Cast")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new [] + { + PredefinedType(Token(SyntaxKind.ByteKeyword)), + elementMarshaller.AsNativeType(elementInfo) + })))))))))); + } + + internal StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) + { + string marshalerIdentifier = GetMarshallerIdentifier(info, context); + string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); + var elementSubContext = new ContiguousCollectionElementMarshallingCodeContext( + context.CurrentStage, + IndexerIdentifier, + nativeSpanIdentifier, + context); + + string collectionIdentifierForLength = useManagedSpanForLength + ? $"{marshalerIdentifier}.ManagedValues" + : nativeSpanIdentifier; + + // Iterate through the elements of the native collection to unmarshal them + return Block( + GenerateNativeSpanDeclaration(GetNativeSpanIdentifier(info, context)), + MarshallerHelpers.GetForLoop(collectionIdentifierForLength, IndexerIdentifier) + .WithStatement(Block( + List(elementMarshaller.Generate( + elementInfo, + elementSubContext))))); + } + + protected override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: true); + } + + protected override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + string marshalerIdentifier = GetMarshallerIdentifier(info, context); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshalerIdentifier), + IdentifierName("SetUnmarshalledCollectionLength"))) + .AddArgumentListArguments(Argument(numElementsExpression))); + } + + protected override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs index 456ce82f6022..50faf74d1c55 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs @@ -10,7 +10,7 @@ namespace Microsoft.Interop { class CustomNativeTypeMarshaller : IMarshallingGenerator { - private const string MarshalerLocalSuffix = "__marshaler"; + internal const string MarshalerLocalSuffix = "__marshaler"; private readonly TypeSyntax _nativeTypeSyntax; private readonly TypeSyntax _nativeLocalTypeSyntax; private readonly SupportedMarshallingMethods _marshallingMethods; @@ -219,6 +219,11 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont case StubCodeContext.Stage.Unmarshal: if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) { + foreach (var statement in GeneratePreUnmarshallingStatements(info, context)) + { + yield return statement; + } + if (_useValueProperty) { // .Value = ; @@ -274,6 +279,11 @@ protected virtual IEnumerable GenerateIntermediateMarshallingSt return Array.Empty(); } + protected virtual IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + protected virtual IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { return Array.Empty(); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs index fd2b8c6200b3..9d293fa338a4 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs @@ -73,6 +73,17 @@ public static RefKind GetRefKindForByValueContentsKind(this ByValueContentsMarsh }; } + public static TypeSyntax GetCompatibleGenericTypeParameterSyntax(this TypeSyntax type) + { + TypeSyntax spanElementTypeSyntax = type; + if (spanElementTypeSyntax is PointerTypeSyntax) + { + // Pointers cannot be passed to generics, so use IntPtr for this case. + spanElementTypeSyntax = ParseTypeName("System.IntPtr"); + } + return spanElementTypeSyntax; + } + public static class StringMarshaller { public static ExpressionSyntax AllocationExpression(CharEncoding encoding, string managedIdentifier) From a0835cc41e206041a6f6cb4276e6b2c64f5d21da Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 27 May 2021 12:05:33 -0700 Subject: [PATCH 03/30] Implement array marshaller in source following the collections model. Implement an array marshaller in generator that wraps the collection marshaller model to support the special pinning and [Out] semantics required for arrays. --- .../Ancillary.Interop/ArrayMarshaller.cs | 105 +++++++++++ .../Marshalling/ArrayMarshaller.cs | 175 ++++++++++++++++++ ...onBlittableElementsMarshallingGenerator.cs | 8 +- ...onBlittableElementsMarshallingGenerator.cs | 10 +- .../Marshalling/CustomNativeTypeMarshaller.cs | 16 +- 5 files changed, 297 insertions(+), 17 deletions(-) create mode 100644 DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs diff --git a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs new file mode 100644 index 000000000000..5ff669bd8a95 --- /dev/null +++ b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs @@ -0,0 +1,105 @@ + +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace System.Runtime.InteropServices.GeneratedMarshalling +{ + public unsafe ref struct ArrayMarshaller + { + private T[]? managedArray; + private readonly int sizeOfNativeElement; + private IntPtr allocatedMemory; + + public ArrayMarshaller(T[]? managed, int sizeOfNativeElement) + { + allocatedMemory = default; + this.sizeOfNativeElement = sizeOfNativeElement; + if (managed is null) + { + managedArray = null; + NativeValueStorage = default; + return; + } + managedArray = managed; + this.sizeOfNativeElement = sizeOfNativeElement; + // Always allocate at least one byte when the array is zero-length. + int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); + allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + } + + public ArrayMarshaller(T[]? managed, Span stackSpace, int sizeOfNativeElement) + { + allocatedMemory = default; + this.sizeOfNativeElement = sizeOfNativeElement; + if (managed is null) + { + managedArray = null; + NativeValueStorage = default; + return; + } + managedArray = managed; + // Always allocate at least one byte when the array is zero-length. + int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); + if (spaceToAllocate < stackSpace.Length) + { + NativeValueStorage = stackSpace.Slice(spaceToAllocate); + } + else + { + allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + } + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of array parameters doesn't + /// blow the stack since this is a new optimization in the code-generated interop. + /// + public const int StackBufferSize = 0x200; + + public Span ManagedValues => managedArray; + + public Span NativeValueStorage { get; private set; } + + public ref byte GetPinnableReference() => ref MemoryMarshal.GetReference(NativeValueStorage); + + public void SetUnmarshalledCollectionLength(int length) + { + managedArray = new T[length]; + } + + public byte* Value + { + get + { + Debug.Assert(managedArray is null || allocatedMemory != IntPtr.Zero); + return (byte*)allocatedMemory; + } + set + { + if (value == null) + { + managedArray = null; + NativeValueStorage = default; + } + else + { + NativeValueStorage = new Span(value, (managedArray?.Length ?? 0) * sizeOfNativeElement); + } + + } + } + + public T[]? ToManaged() => managedArray; + + public void FreeNative() + { + if (allocatedMemory != IntPtr.Zero) + { + Marshal.FreeCoTaskMem(allocatedMemory); + } + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs new file mode 100644 index 000000000000..d14dcdf1411b --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -0,0 +1,175 @@ +using System; +using System.Collections.Generic; + +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + internal class ArrayMarshaller : CustomNativeTypeMarshaller + { + private CustomNativeTypeMarshaller innerCollectionMarshaller; + + private bool blittable; + + public ArrayMarshaller( + ContiguousCollectionBlittableElementsMarshallingGenerator innerCollectionMarshaller, + NativeMarshallingAttributeInfo marshallingInfo) + : base(marshallingInfo) + { + this.innerCollectionMarshaller = innerCollectionMarshaller; + blittable = true; + } + + public ArrayMarshaller( + ContiguousCollectionNonBlittableElementsMarshallingGenerator innerCollectionMarshaller, + NativeMarshallingAttributeInfo marshallingInfo) + : base(marshallingInfo) + { + this.innerCollectionMarshaller = innerCollectionMarshaller; + blittable = true; + } + + private bool UseCustomPinningPath(TypePositionInfo info, StubCodeContext context) + { + return blittable && !info.IsByRef && !info.IsManagedReturnPosition && context.PinningSupported; + } + + public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) + { + if (UseCustomPinningPath(info, context)) + { + return GenerateCustomPinning(); + } + + if (context.CurrentStage == StubCodeContext.Stage.Unmarshal + && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + { + return GenerateByValueOutUnmarshalling(); + } + + return innerCollectionMarshaller.Generate(info, context); + + IEnumerable GenerateCustomPinning() + { + var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); + string byRefIdentifier = $"__byref_{managedIdentifer}"; + TypeSyntax arrayElementType = ((IArrayTypeSymbol)info.ManagedType).ElementType.AsTypeSyntax(); + if (context.CurrentStage == StubCodeContext.Stage.Marshal) + { + // [COMPAT] We use explicit byref calculations here instead of just using a fixed statement + // since a fixed statement converts a zero-length array to a null pointer. + // Many native APIs, such as GDI+, ICU, etc. validate that an array parameter is non-null + // even when the passed in array length is zero. To avoid breaking customers that want to move + // to source-generated interop in subtle ways, we explicitly pass a reference to the 0-th element + // of an array as long as it is non-null, matching the behavior of the built-in interop system + // for single-dimensional zero-based arrays. + + // ref = == null ? ref *(); + var nullRef = + PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, + CastExpression( + PointerType(arrayElementType), + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))); + + var getArrayDataReference = + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + IdentifierName("GetArrayDataReference")), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifer))))); + + yield return LocalDeclarationStatement( + VariableDeclaration( + RefType(arrayElementType)) + .WithVariables(SingletonSeparatedList( + VariableDeclarator(Identifier(byRefIdentifier)) + .WithInitializer(EqualsValueClause( + RefExpression(ParenthesizedExpression( + ConditionalExpression( + BinaryExpression( + SyntaxKind.EqualsExpression, + IdentifierName(managedIdentifer), + LiteralExpression( + SyntaxKind.NullLiteralExpression)), + RefExpression(nullRef), + RefExpression(getArrayDataReference))))))))); + } + if (context.CurrentStage == StubCodeContext.Stage.Pin) + { + // fixed ( = &) + yield return FixedStatement( + VariableDeclaration(AsNativeType(info), SingletonSeparatedList( + VariableDeclarator(nativeIdentifier) + .WithInitializer(EqualsValueClause( + PrefixUnaryExpression(SyntaxKind.AddressOfExpression, + IdentifierName(byRefIdentifier)))))), + EmptyStatement()); + } + yield break; + } + + IEnumerable GenerateByValueOutUnmarshalling() + { + var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); + // For [Out] by value unmarshalling, we emit custom code that only assigns the + // Value property and copy the elements. + // We do not call SetUnmarshalledCollectionLength since that creates a new + // array, and we want to fill the original one. + string marshalerIdentifier = innerCollectionMarshaller.GetMarshallerIdentifier(info, context); + // .Value = ; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshalerIdentifier), + IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), + IdentifierName(nativeIdentifier))); + + foreach (var statement in innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context)) + { + yield return statement; + } + } + } + + public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return innerCollectionMarshaller.GenerateAdditionalNativeTypeConstructorArguments(info, context); + } + + public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + if (info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + { + // Don't marshal contents of an array when it is marshalled by value [Out]. + return Array.Empty(); + } + return innerCollectionMarshaller.GenerateIntermediateMarshallingStatements(info, context); + } + + public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + return innerCollectionMarshaller.GeneratePreUnmarshallingStatements(info, context); + } + + public override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + { + return innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context); + } + + public override bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) + { + return !(blittable && context.PinningSupported) && marshalKind.HasFlag(ByValueContentsMarshalKind.Out); + } + + public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return !UseCustomPinningPath(info, context) && innerCollectionMarshaller.UsesNativeIdentifier(info, context); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs index be4fc800197f..1e26bc0f433e 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -23,12 +23,12 @@ public ContiguousCollectionBlittableElementsMarshallingGenerator( this.numElementsExpression = numElementsExpression; } - protected override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { yield return Argument(SizeOfExpression(elementType.AsTypeSyntax())); } - protected override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); // .ManagedValues.CopyTo(MemoryMarshal.Cast>( GenerateIntermediateMarshallingS IdentifierName("NativeValueStorage"))))))); } - protected override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); yield return ExpressionStatement( @@ -77,7 +77,7 @@ protected override IEnumerable GeneratePreUnmarshallingStatemen .AddArgumentListArguments(Argument(numElementsExpression))); } - protected override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); // MemoryMarshal.Cast>(.ManagedValues); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs index e32ea09f196e..b5b2cbd1b030 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -27,7 +27,7 @@ public ContiguousCollectionNonBlittableElementsMarshallingGenerator( this.numElementsExpression = numElementsExpression; } - protected override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { yield return Argument(SizeOfExpression(elementMarshaller.AsNativeType(elementInfo))); } @@ -64,7 +64,7 @@ private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(string nat })))))))))); } - internal StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) + private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); @@ -88,12 +88,12 @@ internal StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo i elementSubContext))))); } - protected override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) { yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: true); } - protected override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); yield return ExpressionStatement( @@ -105,7 +105,7 @@ protected override IEnumerable GeneratePreUnmarshallingStatemen .AddArgumentListArguments(Argument(numElementsExpression))); } - protected override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs index 50faf74d1c55..fce8c225d7d6 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs @@ -38,7 +38,7 @@ public CustomNativeTypeMarshaller(GeneratedNativeMarshallingAttributeInfo marsha _marshalerTypePinnable = false; } - protected string GetMarshallerIdentifier(TypePositionInfo info, StubCodeContext context) + public string GetMarshallerIdentifier(TypePositionInfo info, StubCodeContext context) { var (_, nativeIdentifier) = context.GetIdentifiers(info); return _useValueProperty ? nativeIdentifier + MarshalerLocalSuffix : nativeIdentifier; @@ -77,7 +77,7 @@ public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) return Argument(IdentifierName(identifier)); } - public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) + public virtual IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); string marshalerIdentifier = GetMarshallerIdentifier(info, context); @@ -269,27 +269,27 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } } - protected virtual IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + public virtual IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { return Array.Empty(); } - protected virtual IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + public virtual IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) { return Array.Empty(); } - protected virtual IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public virtual IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { return Array.Empty(); } - protected virtual IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public virtual IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { return Array.Empty(); } - public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + public virtual bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) { if (info.IsManagedReturnPosition || info.IsByRef && info.RefKind != RefKind.In) { @@ -309,6 +309,6 @@ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) return true; } - public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) => false; + public virtual bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) => false; } } From 7ceeadf21645122463b1813d5ee132dcc89d23ed Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 27 May 2021 17:15:41 -0700 Subject: [PATCH 04/30] Move arrays over to collection-based marshalling. --- .../Marshalling/ArrayMarshaller.cs | 4 +- .../Marshalling/BlittableArrayMarshaller.cs | 313 ----------------- ...onBlittableElementsMarshallingGenerator.cs | 5 +- ...onBlittableElementsMarshallingGenerator.cs | 2 +- .../Marshalling/MarshallingGenerator.cs | 108 +++--- .../NonBlittableArrayMarshaller.cs | 317 ------------------ .../MarshallingAttributeInfo.cs | 68 ++-- .../DllImportGenerator/TypeNames.cs | 2 + .../DllImportGenerator/TypePositionInfo.cs | 96 ++++-- 9 files changed, 171 insertions(+), 744 deletions(-) delete mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs delete mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs index d14dcdf1411b..7cbb3db6f503 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -16,7 +16,7 @@ internal class ArrayMarshaller : CustomNativeTypeMarshaller public ArrayMarshaller( ContiguousCollectionBlittableElementsMarshallingGenerator innerCollectionMarshaller, - NativeMarshallingAttributeInfo marshallingInfo) + NativeContiguousCollectionMarshallingInfo marshallingInfo) : base(marshallingInfo) { this.innerCollectionMarshaller = innerCollectionMarshaller; @@ -25,7 +25,7 @@ public ArrayMarshaller( public ArrayMarshaller( ContiguousCollectionNonBlittableElementsMarshallingGenerator innerCollectionMarshaller, - NativeMarshallingAttributeInfo marshallingInfo) + NativeContiguousCollectionMarshallingInfo marshallingInfo) : base(marshallingInfo) { this.innerCollectionMarshaller = innerCollectionMarshaller; diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs deleted file mode 100644 index cd42a96d8324..000000000000 --- a/DllImportGenerator/DllImportGenerator/Marshalling/BlittableArrayMarshaller.cs +++ /dev/null @@ -1,313 +0,0 @@ -using System.Collections.Generic; - -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - internal class BlittableArrayMarshaller : ConditionalStackallocMarshallingGenerator - { - /// - /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. - /// Number kept small to ensure that P/Invokes with a lot of small array parameters doesn't - /// blow the stack since this is a new optimization in the code-generated interop. - /// - private const int StackAllocBytesThreshold = 0x200; - private readonly ExpressionSyntax _numElementsExpr; - - public BlittableArrayMarshaller(ExpressionSyntax numElementsExpr) - { - _numElementsExpr = numElementsExpr; - } - - private TypeSyntax GetElementTypeSyntax(TypePositionInfo info) - { - return ((IArrayTypeSymbol)info.ManagedType).ElementType.AsTypeSyntax(); - } - - public override TypeSyntax AsNativeType(TypePositionInfo info) - { - return PointerType(GetElementTypeSyntax(info)); - } - - public override ParameterSyntax AsParameter(TypePositionInfo info) - { - var type = info.IsByRef - ? PointerType(AsNativeType(info)) - : AsNativeType(info); - return Parameter(Identifier(info.InstanceIdentifier)) - .WithType(type); - } - - public override ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) - { - return info.IsByRef - ? Argument( - PrefixUnaryExpression( - SyntaxKind.AddressOfExpression, - IdentifierName(context.GetIdentifiers(info).native))) - : Argument(IdentifierName(context.GetIdentifiers(info).native)); - } - - public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) - { - var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); - if (!info.IsByRef && !info.IsManagedReturnPosition && context.PinningSupported) - { - string byRefIdentifier = $"__byref_{managedIdentifer}"; - if (context.CurrentStage == StubCodeContext.Stage.Marshal) - { - // [COMPAT] We use explicit byref calculations here instead of just using a fixed statement - // since a fixed statement converts a zero-length array to a null pointer. - // Many native APIs, such as GDI+, ICU, etc. validate that an array parameter is non-null - // even when the passed in array length is zero. To avoid breaking customers that want to move - // to source-generated interop in subtle ways, we explicitly pass a reference to the 0-th element - // of an array as long as it is non-null, matching the behavior of the built-in interop system - // for single-dimensional zero-based arrays. - - // ref = == null ? ref *(); - var nullRef = - PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, - CastExpression( - PointerType(GetElementTypeSyntax(info)), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))); - - var getArrayDataReference = - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), - IdentifierName("GetArrayDataReference")), - ArgumentList(SingletonSeparatedList( - Argument(IdentifierName(managedIdentifer))))); - - yield return LocalDeclarationStatement( - VariableDeclaration( - RefType(GetElementTypeSyntax(info))) - .WithVariables(SingletonSeparatedList( - VariableDeclarator(Identifier(byRefIdentifier)) - .WithInitializer(EqualsValueClause( - RefExpression(ParenthesizedExpression( - ConditionalExpression( - BinaryExpression( - SyntaxKind.EqualsExpression, - IdentifierName(managedIdentifer), - LiteralExpression( - SyntaxKind.NullLiteralExpression)), - RefExpression(nullRef), - RefExpression(getArrayDataReference))))))))); - } - if (context.CurrentStage == StubCodeContext.Stage.Pin) - { - // fixed ( = &) - yield return FixedStatement( - VariableDeclaration(AsNativeType(info), SingletonSeparatedList( - VariableDeclarator(nativeIdentifier) - .WithInitializer(EqualsValueClause( - PrefixUnaryExpression(SyntaxKind.AddressOfExpression, - IdentifierName(byRefIdentifier)))))), - EmptyStatement()); - } - yield break; - } - - TypeSyntax spanElementTypeSyntax = GetElementTypeSyntax(info); - if (spanElementTypeSyntax is PointerTypeSyntax) - { - // Pointers cannot be passed to generics, so use IntPtr for this case. - spanElementTypeSyntax = ParseTypeName("System.IntPtr"); - } - - switch (context.CurrentStage) - { - case StubCodeContext.Stage.Setup: - if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup)) - yield return conditionalAllocSetup; - - break; - case StubCodeContext.Stage.Marshal: - if (info.RefKind != RefKind.Out) - { - foreach (var statement in GenerateConditionalAllocationSyntax( - info, - context, - StackAllocBytesThreshold)) - { - yield return statement; - } - - // new Span(nativeIdentifier, managedIdentifier.Length) - var nativeSpan = ObjectCreationExpression( - GenericName(TypeNames.System_Span) - .WithTypeArgumentList( - TypeArgumentList( - SingletonSeparatedList(spanElementTypeSyntax)))) - .WithArgumentList( - ArgumentList( - SeparatedList( - new []{ - Argument( - CastExpression( - PointerType(spanElementTypeSyntax), - IdentifierName(nativeIdentifier))), - Argument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(managedIdentifer), - IdentifierName("Length"))) - }))); - - // new Span(managedIdentifier).CopyTo(); - yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedIdentifer), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ObjectCreationExpression( - GenericName(Identifier(TypeNames.System_Span), - TypeArgumentList( - SingletonSeparatedList( - spanElementTypeSyntax)))) - .WithArgumentList( - ArgumentList(SingletonSeparatedList( - Argument(IdentifierName(managedIdentifer))))), - IdentifierName("CopyTo"))) - .WithArgumentList( - ArgumentList( - SingletonSeparatedList( - Argument(nativeSpan)))))); - } - break; - case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition - || (info.IsByRef && info.RefKind != RefKind.In) - || info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) - { - // new Span(nativeIdentifier, managedIdentifier.Length).CopyTo(managedIdentifier); - var unmarshalContentsStatement = - ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ObjectCreationExpression( - GenericName(Identifier(TypeNames.System_Span), - TypeArgumentList( - SingletonSeparatedList( - spanElementTypeSyntax)))) - .WithArgumentList( - ArgumentList( - SeparatedList( - new[]{ - Argument(CastExpression( - PointerType(spanElementTypeSyntax), - IdentifierName(nativeIdentifier))), - Argument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(managedIdentifer), - IdentifierName("Length"))) - }))), - IdentifierName("CopyTo"))) - .WithArgumentList( - ArgumentList( - SingletonSeparatedList( - Argument(IdentifierName(managedIdentifer)))))); - - if (info.IsManagedReturnPosition || info.IsByRef) - { - yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(nativeIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - Block( - // = new []; - ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifer), - ArrayCreationExpression( - ArrayType(GetElementTypeSyntax(info), - SingletonList(ArrayRankSpecifier( - SingletonSeparatedList(_numElementsExpr))))))), - unmarshalContentsStatement), - ElseClause( - ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifer), - LiteralExpression(SyntaxKind.NullLiteralExpression))))); - } - else - { - yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedIdentifer), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - unmarshalContentsStatement); - } - - } - break; - case StubCodeContext.Stage.Cleanup: - yield return GenerateConditionalAllocationFreeSyntax(info, context); - break; - } - } - - public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) - { - return (info.IsByRef || info.IsManagedReturnPosition) || !context.PinningSupported; - } - - protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, out bool allocationRequiresByteLength) - { - allocationRequiresByteLength = true; - // ()Marshal.AllocCoTaskMem() - return CastExpression(AsNativeType(info), - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_Marshal), - IdentifierName("AllocCoTaskMem")), - ArgumentList(SingletonSeparatedList(Argument(IdentifierName(byteLengthIdentifier)))))); - } - - protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context) - { - // checked(sizeof() * .Length) - return CheckedExpression(SyntaxKind.CheckedExpression, - BinaryExpression(SyntaxKind.MultiplyExpression, - SizeOfExpression(GetElementTypeSyntax(info)), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(context.GetIdentifiers(info).managed), - IdentifierName("Length")))); - } - - protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier) - { - return EmptyStatement(); - } - - protected override ExpressionSyntax GenerateFreeExpression(TypePositionInfo info, StubCodeContext context) - { - // Marshal.FreeCoTaskMem((IntPtr)) - return InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_Marshal), - IdentifierName("FreeCoTaskMem")), - ArgumentList(SingletonSeparatedList( - Argument( - CastExpression( - ParseTypeName("System.IntPtr"), - IdentifierName(context.GetIdentifiers(info).native)))))); - } - - public override bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) - { - return !context.PinningSupported && marshalKind.HasFlag(ByValueContentsMarshalKind.Out); - } - } - -} diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs index 1e26bc0f433e..df482fffeafa 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -14,12 +14,11 @@ class ContiguousCollectionBlittableElementsMarshallingGenerator : CustomNativeTy private readonly ExpressionSyntax numElementsExpression; public ContiguousCollectionBlittableElementsMarshallingGenerator( - NativeMarshallingAttributeInfo marshallingInfo, - ITypeSymbol elementType, + NativeContiguousCollectionMarshallingInfo marshallingInfo, ExpressionSyntax numElementsExpression) :base(marshallingInfo) { - this.elementType = elementType; + this.elementType = marshallingInfo.ElementType; this.numElementsExpression = numElementsExpression; } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs index b5b2cbd1b030..b728c68aea34 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -16,7 +16,7 @@ class ContiguousCollectionNonBlittableElementsMarshallingGenerator : CustomNativ private readonly ExpressionSyntax numElementsExpression; public ContiguousCollectionNonBlittableElementsMarshallingGenerator( - NativeMarshallingAttributeInfo marshallingInfo, + NativeContiguousCollectionMarshallingInfo marshallingInfo, IMarshallingGenerator elementMarshaller, TypePositionInfo elementInfo, ExpressionSyntax numElementsExpression) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index c8f13da707c3..28588aa988dd 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -250,7 +250,7 @@ private static IMarshallingGenerator CreateCore( // Must go before the cases that do not explicitly check for marshalling info to support // the user overridding the default marshalling rules with a MarshalUsing attribute. case { MarshallingAttributeInfo: NativeMarshallingAttributeInfo marshalInfo }: - return CreateCustomNativeTypeMarshaller(info, context, marshalInfo); + return CreateCustomNativeTypeMarshaller(info, context, marshalInfo, options); case { MarshallingAttributeInfo: BlittableTypeAttributeInfo }: return Blittable; @@ -266,9 +266,6 @@ private static IMarshallingGenerator CreateCore( case { ManagedType: { SpecialType: SpecialType.System_String } }: return CreateStringMarshaller(info, context); - - case { ManagedType: IArrayTypeSymbol { IsSZArray: true, ElementType: ITypeSymbol elementType } }: - return CreateArrayMarshaller(info, context, options, elementType); case { ManagedType: { SpecialType: SpecialType.System_Void } }: return Forwarder; @@ -366,24 +363,29 @@ private static IMarshallingGenerator CreateStringMarshaller(TypePositionInfo inf throw new MarshallingNotSupportedException(info, context); } - private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(TypePositionInfo info, StubCodeContext context, AnalyzerConfigOptions options) + private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(TypePositionInfo info, CountInfo count, StubCodeContext context, AnalyzerConfigOptions options) { - ExpressionSyntax numElementsExpression; - if (info.MarshallingAttributeInfo is not ArrayMarshalAsInfo marshalAsInfo) + return count switch { - throw new MarshallingNotSupportedException(info, context) + SizeAndParamIndexInfo(int size, SizeAndParamIndexInfo.UnspecifiedData) => GetConstSizeExpression(size), + ConstSizeCountInfo(int size) => GetConstSizeExpression(size), + SizeAndParamIndexInfo(SizeAndParamIndexInfo.UnspecifiedData, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParamIndex(paramIndex)), + SizeAndParamIndexInfo(int size, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, BinaryExpression(SyntaxKind.AddExpression, GetConstSizeExpression(size), GetExpressionForParamIndex(paramIndex))), + CountElementCountInfo(string elementName) => throw new NotImplementedException(), + _ => throw new MarshallingNotSupportedException(info, context) { NotSupportedDetails = Resources.ArraySizeMustBeSpecified - }; + }, + }; + + static LiteralExpressionSyntax GetConstSizeExpression(int size) + { + return LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(size)); } - LiteralExpressionSyntax? constSizeExpression = marshalAsInfo.ArraySizeConst != ArrayMarshalAsInfo.UnspecifiedData - ? LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(marshalAsInfo.ArraySizeConst)) - : null; - ExpressionSyntax? sizeParamIndexExpression = null; - if (marshalAsInfo.ArraySizeParamIndex != ArrayMarshalAsInfo.UnspecifiedData) + ExpressionSyntax GetExpressionForParamIndex(int index) { - TypePositionInfo? paramIndexInfo = context.GetTypePositionInfoForManagedIndex(marshalAsInfo.ArraySizeParamIndex); + TypePositionInfo? paramIndexInfo = context.GetTypePositionInfoForManagedIndex(index); if (paramIndexInfo is null) { throw new MarshallingNotSupportedException(info, context) @@ -402,51 +404,14 @@ private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(Type { var (managed, native) = context.GetIdentifiers(paramIndexInfo); string identifier = Create(paramIndexInfo, context, options).UsesNativeIdentifier(paramIndexInfo, context) ? native : managed; - sizeParamIndexExpression = CastExpression( + return CastExpression( PredefinedType(Token(SyntaxKind.IntKeyword)), IdentifierName(identifier)); } } - numElementsExpression = (constSizeExpression, sizeParamIndexExpression) switch - { - (null, null) => throw new MarshallingNotSupportedException(info, context) - { - NotSupportedDetails = Resources.ArraySizeMustBeSpecified - }, - (not null, null) => constSizeExpression!, - (null, not null) => CheckedExpression(SyntaxKind.CheckedExpression, sizeParamIndexExpression!), - (not null, not null) => CheckedExpression(SyntaxKind.CheckedExpression, BinaryExpression(SyntaxKind.AddExpression, constSizeExpression!, sizeParamIndexExpression!)) - }; - return numElementsExpression; - } - - private static IMarshallingGenerator CreateArrayMarshaller(TypePositionInfo info, StubCodeContext context, AnalyzerConfigOptions options, ITypeSymbol elementType) - { - var elementMarshallingInfo = info.MarshallingAttributeInfo switch - { - ArrayMarshalAsInfo(UnmanagedType.LPArray, _) marshalAs => marshalAs.ElementMarshallingInfo, - ArrayMarshallingInfo marshalInfo => marshalInfo.ElementMarshallingInfo, - NoMarshallingInfo _ => NoMarshallingInfo.Instance, - _ => throw new MarshallingNotSupportedException(info, context) - }; - - var elementMarshaller = Create( - TypePositionInfo.CreateForType(elementType, elementMarshallingInfo), - new ArrayMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, context, false), - options); - ExpressionSyntax numElementsExpression = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)); - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) - { - // In this case, we need a numElementsExpression supplied from metadata, so we'll calculate it here. - numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, context, options); - } - - return elementMarshaller == Blittable - ? new BlittableArrayMarshaller(numElementsExpression) - : new NonBlittableArrayMarshaller(elementMarshaller, numElementsExpression); } - private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo) + private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo, AnalyzerConfigOptions options) { if (marshalInfo.ValuePropertyType is not null && !context.CanUseAdditionalTemporaryState) { @@ -504,8 +469,43 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString()) }; } + + if (marshalInfo is NativeContiguousCollectionMarshallingInfo collectionMarshallingInfo) + { + return CreateNativeCollectionMarshaller(info, context, collectionMarshallingInfo, options); + } return new CustomNativeTypeMarshaller(marshalInfo); } + + private static IMarshallingGenerator CreateNativeCollectionMarshaller(TypePositionInfo info, StubCodeContext context, NativeContiguousCollectionMarshallingInfo collectionMarshallingInfo, AnalyzerConfigOptions options) + { + var elementInfo = TypePositionInfo.CreateForType(collectionMarshallingInfo.ElementType, collectionMarshallingInfo.ElementMarshallingInfo); + var elementMarshaller = Create( + elementInfo, + new ContiguousCollectionElementMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, string.Empty, context), + options); + ExpressionSyntax numElementsExpression = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)); + if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + { + // In this case, we need a numElementsExpression supplied from metadata, so we'll calculate it here. + numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, collectionMarshallingInfo.ElementCountInfo, context, options); + } + + if (collectionMarshallingInfo.UseDefaultMarshalling && info.ManagedType is IArrayTypeSymbol { IsSZArray: true }) + { + if (elementMarshaller == Blittable) + { + return new ArrayMarshaller(new ContiguousCollectionBlittableElementsMarshallingGenerator(collectionMarshallingInfo, numElementsExpression), collectionMarshallingInfo); + } + return new ArrayMarshaller(new ContiguousCollectionNonBlittableElementsMarshallingGenerator(collectionMarshallingInfo, elementMarshaller, elementInfo, numElementsExpression), collectionMarshallingInfo); + } + + if (elementMarshaller == Blittable) + { + return new ContiguousCollectionBlittableElementsMarshallingGenerator(collectionMarshallingInfo, numElementsExpression); + } + return new ContiguousCollectionNonBlittableElementsMarshallingGenerator(collectionMarshallingInfo, elementMarshaller, elementInfo, numElementsExpression); + } } } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs deleted file mode 100644 index 40d793c73101..000000000000 --- a/DllImportGenerator/DllImportGenerator/Marshalling/NonBlittableArrayMarshaller.cs +++ /dev/null @@ -1,317 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - internal class NonBlittableArrayMarshaller : ConditionalStackallocMarshallingGenerator - { - /// - /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. - /// Number kept small to ensure that P/Invokes with a lot of small array parameters doesn't - /// blow the stack since this is a new optimization in the code-generated interop. - /// - private const int StackAllocBytesThreshold = 0x200; - - private const string IndexerIdentifier = "__i"; - - private readonly IMarshallingGenerator _elementMarshaller; - private readonly ExpressionSyntax _numElementsExpr; - - public NonBlittableArrayMarshaller(IMarshallingGenerator elementMarshaller, ExpressionSyntax numElementsExpr) - { - _elementMarshaller = elementMarshaller; - _numElementsExpr = numElementsExpr; - } - - private ITypeSymbol GetElementTypeSymbol(TypePositionInfo info) - { - return ((IArrayTypeSymbol)info.ManagedType).ElementType; - } - - private TypeSyntax GetNativeElementTypeSyntax(TypePositionInfo info) - { - return _elementMarshaller.AsNativeType(TypePositionInfo.CreateForType(GetElementTypeSymbol(info), NoMarshallingInfo.Instance)); - } - - public override TypeSyntax AsNativeType(TypePositionInfo info) - { - return PointerType(GetNativeElementTypeSyntax(info)); - } - - public override ParameterSyntax AsParameter(TypePositionInfo info) - { - var type = info.IsByRef - ? PointerType(AsNativeType(info)) - : AsNativeType(info); - return Parameter(Identifier(info.InstanceIdentifier)) - .WithType(type); - } - - public override ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) - { - return info.IsByRef - ? Argument( - PrefixUnaryExpression( - SyntaxKind.AddressOfExpression, - IdentifierName(context.GetIdentifiers(info).native))) - : Argument(IdentifierName(context.GetIdentifiers(info).native)); - } - - public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) - { - var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); - RefKind elementRefKind = info.IsByRef ? info.RefKind : info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind(); - bool cacheManagedValue = ShouldCacheManagedValue(info, context); - string managedLocal = !cacheManagedValue ? managedIdentifer : managedIdentifer + ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix; - - switch (context.CurrentStage) - { - case StubCodeContext.Stage.Setup: - if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup)) - yield return conditionalAllocSetup; - - if (cacheManagedValue) - { - yield return LocalDeclarationStatement( - VariableDeclaration( - info.ManagedType.AsTypeSyntax(), - SingletonSeparatedList( - VariableDeclarator(managedLocal) - .WithInitializer(EqualsValueClause( - IdentifierName(managedIdentifer)))))); - } - break; - case StubCodeContext.Stage.Marshal: - if (info.RefKind != RefKind.Out) - { - foreach (var statement in GenerateConditionalAllocationSyntax( - info, - context, - StackAllocBytesThreshold)) - { - yield return statement; - } - - var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue); - - TypeSyntax spanElementTypeSyntax = GetNativeElementTypeSyntax(info); - if (spanElementTypeSyntax is PointerTypeSyntax) - { - // Pointers cannot be passed to generics, so use IntPtr for this case. - spanElementTypeSyntax = ParseTypeName("System.IntPtr"); - } - - if (info is { IsByRef: false, ByValueContentsMarshalKind: ByValueContentsMarshalKind.Out }) - { - // We don't marshal values from managed to native for [Out] by value arrays, - // we only allocate the buffer. - yield break; - } - - // Iterate through the elements of the array to marshal them - yield return IfStatement(BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedLocal), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - Block( - // new Span(, .Length).Clear(); - ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ObjectCreationExpression( - GenericName(TypeNames.System_Span) - .WithTypeArgumentList( - TypeArgumentList( - SingletonSeparatedList(spanElementTypeSyntax)))) - .WithArgumentList( - ArgumentList( - SeparatedList( - new []{ - Argument( - CastExpression( - PointerType(spanElementTypeSyntax), - IdentifierName(nativeIdentifier))), - Argument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(managedIdentifer), - IdentifierName("Length"))) - }))), - IdentifierName("Clear")), - ArgumentList())), - MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier) - .WithStatement(Block( - List(_elementMarshaller.Generate( - info with { ManagedType = GetElementTypeSymbol(info), RefKind = elementRefKind }, - arraySubContext)))))); - } - break; - case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition - || (info.IsByRef && info.RefKind != RefKind.In) - || info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) - { - var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue); - // Iterate through the elements of the native array to unmarshal them - StatementSyntax unmarshalContentsStatement = - MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier) - .WithStatement(Block( - List(_elementMarshaller.Generate( - info with { ManagedType = GetElementTypeSymbol(info), RefKind = elementRefKind }, - arraySubContext)))); - - if (!info.IsByRef) - { - if (info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) - { - yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedLocal), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - unmarshalContentsStatement); - - if (cacheManagedValue) - { - yield return ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifer), - IdentifierName(managedLocal)) - ); - } - yield break; - } - } - - yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(nativeIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - Block( - // = new []; - ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedLocal), - ArrayCreationExpression( - ArrayType(GetElementTypeSymbol(info).AsTypeSyntax(), - SingletonList(ArrayRankSpecifier( - SingletonSeparatedList(_numElementsExpr))))))), - unmarshalContentsStatement - ), - ElseClause( - ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedLocal), - LiteralExpression(SyntaxKind.NullLiteralExpression))))); - - if (cacheManagedValue) - { - yield return ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifer), - IdentifierName(managedLocal)) - ); - } - } - break; - case StubCodeContext.Stage.Cleanup: - { - var arraySubContext = new ArrayMarshallingCodeContext(context.CurrentStage, IndexerIdentifier, context, appendLocalManagedIdentifierSuffix: cacheManagedValue); - var elementCleanup = List(_elementMarshaller.Generate(info with { ManagedType = GetElementTypeSymbol(info), RefKind = elementRefKind }, arraySubContext)); - if (elementCleanup.Count != 0) - { - // Iterate through the elements of the native array to clean up any unmanaged resources. - yield return IfStatement( - BinaryExpression(SyntaxKind.NotEqualsExpression, - IdentifierName(managedLocal), - LiteralExpression(SyntaxKind.NullLiteralExpression)), - MarshallerHelpers.GetForLoop(managedLocal, IndexerIdentifier) - .WithStatement(Block(elementCleanup))); - } - yield return GenerateConditionalAllocationFreeSyntax(info, context); - } - break; - } - } - - private static bool ShouldCacheManagedValue(TypePositionInfo info, StubCodeContext context) - { - return info.IsByRef && context.CanUseAdditionalTemporaryState; - } - - public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) - { - return true; - } - - - protected override ExpressionSyntax GenerateAllocationExpression(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, out bool allocationRequiresByteLength) - { - allocationRequiresByteLength = true; - // (*)Marshal.AllocCoTaskMem() - return CastExpression(AsNativeType(info), - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_Marshal), - IdentifierName("AllocCoTaskMem")), - ArgumentList(SingletonSeparatedList(Argument(IdentifierName(byteLengthIdentifier)))))); - } - - protected override ExpressionSyntax GenerateByteLengthCalculationExpression(TypePositionInfo info, StubCodeContext context) - { - string managedIdentifier = context.GetIdentifiers(info).managed; - if (ShouldCacheManagedValue(info, context)) - { - managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix; - } - // checked(sizeof() * .Length) - return CheckedExpression(SyntaxKind.CheckedExpression, - BinaryExpression(SyntaxKind.MultiplyExpression, - SizeOfExpression(GetNativeElementTypeSyntax(info)), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(managedIdentifier), - IdentifierName("Length")))); - } - - protected override StatementSyntax GenerateStackallocOnlyValueMarshalling(TypePositionInfo info, StubCodeContext context, SyntaxToken byteLengthIdentifier, SyntaxToken stackAllocPtrIdentifier) - { - return EmptyStatement(); - } - - protected override ExpressionSyntax GenerateFreeExpression(TypePositionInfo info, StubCodeContext context) - { - // Marshal.FreeCoTaskMem((IntPtr)) - return InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_Marshal), - IdentifierName("FreeCoTaskMem")), - ArgumentList(SingletonSeparatedList( - Argument( - CastExpression( - ParseTypeName("System.IntPtr"), - IdentifierName(context.GetIdentifiers(info).native)))))); - } - - public override bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) - { - return marshalKind.HasFlag(ByValueContentsMarshalKind.Out); - } - - protected override ExpressionSyntax GenerateNullCheckExpression(TypePositionInfo info, StubCodeContext context) - { - string managedIdentifier = context.GetIdentifiers(info).managed; - if (ShouldCacheManagedValue(info, context)) - { - managedIdentifier += ArrayMarshallingCodeContext.LocalManagedIdentifierSuffix; - } - - return BinaryExpression( - SyntaxKind.NotEqualsExpression, - IdentifierName(managedIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)); - } - } -} diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index a9c036948d56..d03e8d539aab 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -43,31 +43,12 @@ CharEncoding CharEncoding /// /// Simple User-application of System.Runtime.InteropServices.MarshalAsAttribute /// - internal record MarshalAsInfo( + internal sealed record MarshalAsInfo( UnmanagedType UnmanagedType, CharEncoding CharEncoding) : MarshallingInfoStringSupport(CharEncoding) { } - enum UnmanagedArrayType - { - LPArray = UnmanagedType.LPArray, - ByValArray = UnmanagedType.ByValArray - } - - /// - /// User-applied System.Runtime.InteropServices.MarshalAsAttribute with array marshalling info - /// - internal sealed record ArrayMarshalAsInfo( - UnmanagedArrayType UnmanagedArrayType, - int ArraySizeConst, - short ArraySizeParamIndex, - CharEncoding CharEncoding, - MarshallingInfo ElementMarshallingInfo) : MarshalAsInfo((UnmanagedType)UnmanagedArrayType, CharEncoding) - { - public const short UnspecifiedData = -1; - } - /// /// User-applied System.Runtime.InteropServices.BlittableTypeAttribute /// or System.Runtime.InteropServices.GeneratedMarshallingAttribute on a blittable type @@ -82,16 +63,38 @@ internal enum SupportedMarshallingMethods NativeToManaged = 0x2, ManagedToNativeStackalloc = 0x4, Pinning = 0x8, + All = -1 + } + + internal abstract record CountInfo; + + internal sealed record NoCountInfo : CountInfo + { + public static readonly NoCountInfo Instance = new NoCountInfo(); + + private NoCountInfo() { } + } + + internal sealed record ConstSizeCountInfo(int Size) : CountInfo; + + internal sealed record CountElementCountInfo(string parameterName) : CountInfo; + + internal sealed record SizeAndParamIndexInfo(int ConstSize, int ParamIndex) : CountInfo + { + public const int UnspecifiedData = -1; + + public static readonly SizeAndParamIndexInfo Unspecified = new(UnspecifiedData, UnspecifiedData); } /// /// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute /// - internal sealed record NativeMarshallingAttributeInfo( + internal record NativeMarshallingAttributeInfo( ITypeSymbol NativeMarshallingType, ITypeSymbol? ValuePropertyType, SupportedMarshallingMethods MarshallingMethods, - bool NativeTypePinnable) : MarshallingInfo; + bool NativeTypePinnable, + bool UseDefaultMarshalling) : MarshallingInfo; /// /// User-applied System.Runtime.InteropServices.GeneratedMarshallingAttribute @@ -105,9 +108,22 @@ internal sealed record GeneratedNativeMarshallingAttributeInfo( /// internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo; - /// - /// Default marshalling for arrays - /// - internal sealed record ArrayMarshallingInfo(MarshallingInfo ElementMarshallingInfo) : MarshallingInfo; + /// User-applied System.Runtime.InteropServices.NativeMarshalllingAttribute + /// with a contiguous collection marshaller + internal sealed record NativeContiguousCollectionMarshallingInfo( + ITypeSymbol NativeMarshallingType, + ITypeSymbol? ValuePropertyType, + SupportedMarshallingMethods MarshallingMethods, + bool NativeTypePinnable, + bool UseDefaultMarshalling, + CountInfo ElementCountInfo, + ITypeSymbol ElementType, + MarshallingInfo ElementMarshallingInfo) : NativeMarshallingAttributeInfo( + NativeMarshallingType, + ValuePropertyType, + MarshallingMethods, + NativeTypePinnable, + UseDefaultMarshalling + ); } diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index 509969a05230..e737f487393d 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -36,6 +36,8 @@ public static string MarshalEx(AnalyzerConfigOptions options) { return options.UseMarshalType() ? System_Runtime_InteropServices_Marshal : System_Runtime_InteropServices_MarshalEx; } + + public const string System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata = "System.Runtime.InteropServices.GeneratedMarshalling.ArrayMarshaller`1"; public const string System_Runtime_InteropServices_MemoryMarshal = "System.Runtime.InteropServices.MemoryMarshal"; diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 8d23b274ee3a..6c14a5ebd6da 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -127,7 +127,7 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo m return typeInfo; } - private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) + private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, int indirectionLevel = 0) { // Look at attributes passed in - usage specific. foreach (var attrData in attributes) @@ -137,11 +137,11 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) { // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics, scopeSymbol); + return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass)) { - return CreateNativeMarshallingInfo(type, compilation, attrData, allowGetPinnableReference: false); + return CreateNativeMarshallingInfo(type, compilation, attrData, useDefaultMarshalling: false, indirectionLevel); } } @@ -162,7 +162,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.NativeMarshallingAttribute), attributeClass)) { - return CreateNativeMarshallingInfo(type, compilation, attrData, allowGetPinnableReference: true); + return CreateNativeMarshallingInfo(type, compilation, attrData, useDefaultMarshalling: true, indirectionLevel); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.GeneratedMarshallingAttribute), attributeClass)) { @@ -172,7 +172,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< // If the type doesn't have custom attributes that dictate marshalling, // then consider the type itself. - if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, scopeSymbol, out MarshallingInfo infoMaybe)) + if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, scopeSymbol, out MarshallingInfo infoMaybe, indirectionLevel)) { return infoMaybe; } @@ -188,7 +188,14 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< return NoMarshallingInfo.Instance; - static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) + static MarshallingInfo CreateMarshalAsInfo( + ITypeSymbol type, + AttributeData attrData, + DefaultMarshallingInfo defaultInfo, + Compilation compilation, + GeneratorDiagnostics diagnostics, + INamedTypeSymbol scopeSymbol, + int indirectionLevel) { object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; UnmanagedType unmanagedType = unmanagedTypeObj is short @@ -201,9 +208,8 @@ static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrDat diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); } bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; - UnmanagedType unmanagedArraySubType = (UnmanagedType)ArrayMarshalAsInfo.UnspecifiedData; - int arraySizeConst = ArrayMarshalAsInfo.UnspecifiedData; - short arraySizeParamIndex = ArrayMarshalAsInfo.UnspecifiedData; + UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedData; + SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; // All other data on attribute is defined as NamedArguments. foreach (var namedArg in attrData.NamedArguments) @@ -226,21 +232,21 @@ static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrDat { diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); } - unmanagedArraySubType = (UnmanagedType)namedArg.Value.Value!; + elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; break; case nameof(MarshalAsAttribute.SizeConst): if (!isArrayType) { diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); } - arraySizeConst = (int)namedArg.Value.Value!; + arraySizeInfo = arraySizeInfo with { ConstSize = (int)namedArg.Value.Value! }; break; case nameof(MarshalAsAttribute.SizeParamIndex): if (!isArrayType) { diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); } - arraySizeParamIndex = (short)namedArg.Value.Value!; + arraySizeInfo = arraySizeInfo with { ParamIndex = (short)namedArg.Value.Value! }; break; } } @@ -250,26 +256,42 @@ static MarshalAsInfo CreateMarshalAsInfo(ITypeSymbol type, AttributeData attrDat return new MarshalAsInfo(unmanagedType, defaultInfo.CharEncoding); } + if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + { + diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); + return NoMarshallingInfo.Instance; + } + MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; - if (unmanagedArraySubType != (UnmanagedType)ArrayMarshalAsInfo.UnspecifiedData) + if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedData) + { + elementMarshallingInfo = new MarshalAsInfo(elementUnmanagedType, defaultInfo.CharEncoding); + } + else { - elementMarshallingInfo = new MarshalAsInfo(unmanagedArraySubType, defaultInfo.CharEncoding); + elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel++); } - else if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + + INamedTypeSymbol? arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + + if (arrayMarshaller is null) { - elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol); + // If the array marshaler type is not available, then we cannot marshal arrays. + return NoMarshallingInfo.Instance; } - return new ArrayMarshalAsInfo( - UnmanagedArrayType: (UnmanagedArrayType)unmanagedType, - ArraySizeConst: arraySizeConst, - ArraySizeParamIndex: arraySizeParamIndex, - CharEncoding: defaultInfo.CharEncoding, - ElementMarshallingInfo: elementMarshallingInfo - ); + return new NativeContiguousCollectionMarshallingInfo( + NativeMarshallingType: arrayMarshaller, + ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, + MarshallingMethods: SupportedMarshallingMethods.All, + NativeTypePinnable : true, + UseDefaultMarshalling: true, + ElementCountInfo: arraySizeInfo, + ElementType: elementType, + ElementMarshallingInfo: elementMarshallingInfo); } - static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, AttributeData attrData, bool allowGetPinnableReference) + static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, AttributeData attrData, bool useDefaultMarshalling, int indirectionLevel) { ITypeSymbol spanOfByte = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(compilation.GetSpecialType(SpecialType.System_Byte)); INamedTypeSymbol nativeType = (INamedTypeSymbol)attrData.ConstructorArguments[0].Value!; @@ -295,7 +317,7 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty methods |= SupportedMarshallingMethods.NativeToManaged; } - if (allowGetPinnableReference && ManualTypeMarshallingHelper.FindGetPinnableReference(type) != null) + if (useDefaultMarshalling && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) { methods |= SupportedMarshallingMethods.Pinning; } @@ -309,10 +331,11 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty nativeType, valueProperty?.Type, methods, - NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null); + NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, + UseDefaultMarshalling: useDefaultMarshalling); } - static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, out MarshallingInfo marshallingInfo) + static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, out MarshallingInfo marshallingInfo, int indirectionLevel) { var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); if (conversion.Exists @@ -337,7 +360,24 @@ static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshalli if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) { - marshallingInfo = new ArrayMarshallingInfo(GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol)); + INamedTypeSymbol? arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + + if (arrayMarshaller is null) + { + // If the array marshaler type is not available, then we cannot marshal arrays. + marshallingInfo = NoMarshallingInfo.Instance; + return false; + } + + marshallingInfo = new NativeContiguousCollectionMarshallingInfo( + NativeMarshallingType: arrayMarshaller!, + ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller!)?.Type, + MarshallingMethods: SupportedMarshallingMethods.All, + NativeTypePinnable: true, + UseDefaultMarshalling: true, + ElementCountInfo: NoCountInfo.Instance, + ElementType: elementType, + ElementMarshallingInfo: GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel + 1)); return true; } From d26f77225e3ab5030687d3689ab9191fc588b883 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 28 May 2021 11:24:16 -0700 Subject: [PATCH 05/30] Fix unit test failures and delete dead code. Add special marshaller for arrays of pointers to ensure they are still supported. --- .../Ancillary.Interop/ArrayMarshaller.cs | 104 +++++++++++++++++- .../ArrayMarshallingCodeContext.cs | 81 -------------- ...CollectionElementMarshallingCodeContext.cs | 2 +- .../Marshalling/ArrayMarshaller.cs | 45 +++----- ...onBlittableElementsMarshallingGenerator.cs | 4 +- ...onBlittableElementsMarshallingGenerator.cs | 63 +++++++++-- .../Marshalling/CustomNativeTypeMarshaller.cs | 9 +- .../Marshalling/MarshallerHelpers.cs | 13 ++- .../DllImportGenerator/TypeNames.cs | 8 +- .../DllImportGenerator/TypePositionInfo.cs | 26 ++++- 10 files changed, 218 insertions(+), 137 deletions(-) delete mode 100644 DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs diff --git a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs index 5ff669bd8a95..9d5cba5b591e 100644 --- a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs +++ b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs @@ -43,7 +43,7 @@ public ArrayMarshaller(T[]? managed, Span stackSpace, int sizeOfNativeElem int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); if (spaceToAllocate < stackSpace.Length) { - NativeValueStorage = stackSpace.Slice(spaceToAllocate); + NativeValueStorage = stackSpace[0..spaceToAllocate]; } else { @@ -86,9 +86,9 @@ public byte* Value } else { + allocatedMemory = (IntPtr)value; NativeValueStorage = new Span(value, (managedArray?.Length ?? 0) * sizeOfNativeElement); } - } } @@ -102,4 +102,104 @@ public void FreeNative() } } } + + public unsafe ref struct PtrArrayMarshaller where T : unmanaged + { + private T*[]? managedArray; + private readonly int sizeOfNativeElement; + private IntPtr allocatedMemory; + + public PtrArrayMarshaller(T*[]? managed, int sizeOfNativeElement) + { + allocatedMemory = default; + this.sizeOfNativeElement = sizeOfNativeElement; + if (managed is null) + { + managedArray = null; + NativeValueStorage = default; + return; + } + managedArray = managed; + this.sizeOfNativeElement = sizeOfNativeElement; + // Always allocate at least one byte when the array is zero-length. + int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); + allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + } + + public PtrArrayMarshaller(T*[]? managed, Span stackSpace, int sizeOfNativeElement) + { + allocatedMemory = default; + this.sizeOfNativeElement = sizeOfNativeElement; + if (managed is null) + { + managedArray = null; + NativeValueStorage = default; + return; + } + managedArray = managed; + // Always allocate at least one byte when the array is zero-length. + int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); + if (spaceToAllocate < stackSpace.Length) + { + NativeValueStorage = stackSpace[0..spaceToAllocate]; + } + else + { + allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + } + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of array parameters doesn't + /// blow the stack since this is a new optimization in the code-generated interop. + /// + public const int StackBufferSize = 0x200; + + public Span ManagedValues => Unsafe.As(managedArray); + + public Span NativeValueStorage { get; private set; } + + public ref byte GetPinnableReference() => ref MemoryMarshal.GetReference(NativeValueStorage); + + public void SetUnmarshalledCollectionLength(int length) + { + managedArray = new T*[length]; + } + + public byte* Value + { + get + { + Debug.Assert(managedArray is null || allocatedMemory != IntPtr.Zero); + return (byte*)allocatedMemory; + } + set + { + if (value == null) + { + managedArray = null; + NativeValueStorage = default; + } + else + { + allocatedMemory = (IntPtr)value; + NativeValueStorage = new Span(value, (managedArray?.Length ?? 0) * sizeOfNativeElement); + } + + } + } + + public T*[]? ToManaged() => managedArray; + + public void FreeNative() + { + if (allocatedMemory != IntPtr.Zero) + { + Marshal.FreeCoTaskMem(allocatedMemory); + } + } + } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs b/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs deleted file mode 100644 index 9e8d1ab5a200..000000000000 --- a/DllImportGenerator/DllImportGenerator/ArrayMarshallingCodeContext.cs +++ /dev/null @@ -1,81 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; - -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - internal sealed class ArrayMarshallingCodeContext : StubCodeContext - { - public const string LocalManagedIdentifierSuffix = "_local"; - - private readonly string indexerIdentifier; - private readonly StubCodeContext parentContext; - private readonly bool appendLocalManagedIdentifierSuffix; - - public override bool PinningSupported => false; - - public override bool StackSpaceUsable => false; - - /// - /// Additional variables other than the {managedIdentifier} and {nativeIdentifier} variables - /// can be added to the stub to track additional state for the marshaller in the stub. - /// - /// - /// Currently, array scenarios do not support declaring additional temporary variables to support - /// marshalling. This can be accomplished in the future with some additional infrastructure to support - /// declaring arrays additional arrays in the stub to support the temporary state. - /// - public override bool CanUseAdditionalTemporaryState => false; - - /// - /// Create a for marshalling elements of an array. - /// - /// The current marshalling stage. - /// The indexer in the loop to get the element to marshal from the array. - /// The parent context. - /// - /// For array marshalling, we sometimes cache the array in a local to avoid multithreading issues. - /// Set this to true to add the to the managed identifier when - /// marshalling the array elements to ensure that we use the local copy instead of the managed identifier - /// when marshalling elements. - /// - public ArrayMarshallingCodeContext( - Stage currentStage, - string indexerIdentifier, - StubCodeContext parentContext, - bool appendLocalManagedIdentifierSuffix) - { - CurrentStage = currentStage; - this.indexerIdentifier = indexerIdentifier; - this.parentContext = parentContext; - this.appendLocalManagedIdentifierSuffix = appendLocalManagedIdentifierSuffix; - } - - /// - /// Get managed and native instance identifiers for the - /// - /// Object for which to get identifiers - /// Managed and native identifiers - public override (string managed, string native) GetIdentifiers(TypePositionInfo info) - { - var (managed, native) = parentContext.GetIdentifiers(info); - if (appendLocalManagedIdentifierSuffix) - { - managed += LocalManagedIdentifierSuffix; - } - return ($"{managed}[{indexerIdentifier}]", $"{native}[{indexerIdentifier}]"); - } - - public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index) - { - // We don't have parameters to look at when we're in the middle of marshalling an array. - return null; - } - } -} diff --git a/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs b/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs index e203391abf3d..d7cc9123e731 100644 --- a/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs +++ b/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs @@ -59,7 +59,7 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo { var (managed, _) = parentContext.GetIdentifiers(info); return ( - $"{managed}{CustomNativeTypeMarshaller.MarshalerLocalSuffix}.ManagedValues[{indexerIdentifier}]", + $"{MarshallerHelpers.GetMarshallerIdentifier(info, parentContext)}.ManagedValues[{indexerIdentifier}]", $"{nativeSpanIdentifier}[{indexerIdentifier}]" ); } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs index 7cbb3db6f503..62100a03d513 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -10,9 +10,8 @@ namespace Microsoft.Interop { internal class ArrayMarshaller : CustomNativeTypeMarshaller { - private CustomNativeTypeMarshaller innerCollectionMarshaller; - - private bool blittable; + private readonly CustomNativeTypeMarshaller innerCollectionMarshaller; + private readonly bool blittable; public ArrayMarshaller( ContiguousCollectionBlittableElementsMarshallingGenerator innerCollectionMarshaller, @@ -29,7 +28,7 @@ public ArrayMarshaller( : base(marshallingInfo) { this.innerCollectionMarshaller = innerCollectionMarshaller; - blittable = true; + blittable = false; } private bool UseCustomPinningPath(TypePositionInfo info, StubCodeContext context) @@ -47,7 +46,10 @@ public override IEnumerable Generate(TypePositionInfo info, Stu if (context.CurrentStage == StubCodeContext.Stage.Unmarshal && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { - return GenerateByValueOutUnmarshalling(); + // For [Out] by value unmarshalling, we emit custom code that only copies the elements. + // We do not call SetUnmarshalledCollectionLength since that creates a new + // array, and we want to fill the original one. + return innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context); } return innerCollectionMarshaller.Generate(info, context); @@ -107,34 +109,19 @@ IEnumerable GenerateCustomPinning() VariableDeclarator(nativeIdentifier) .WithInitializer(EqualsValueClause( PrefixUnaryExpression(SyntaxKind.AddressOfExpression, - IdentifierName(byRefIdentifier)))))), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_CompilerServices_Unsafe), + GenericName("As").AddTypeArgumentListArguments( + arrayElementType, + PredefinedType(Token(SyntaxKind.ByteKeyword))))) + .AddArgumentListArguments( + Argument(IdentifierName(byRefIdentifier)) + .WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))))), EmptyStatement()); } yield break; } - - IEnumerable GenerateByValueOutUnmarshalling() - { - var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); - // For [Out] by value unmarshalling, we emit custom code that only assigns the - // Value property and copy the elements. - // We do not call SetUnmarshalledCollectionLength since that creates a new - // array, and we want to fill the original one. - string marshalerIdentifier = innerCollectionMarshaller.GetMarshallerIdentifier(info, context); - // .Value = ; - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), - IdentifierName(nativeIdentifier))); - - foreach (var statement in innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context)) - { - yield return statement; - } - } } public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs index df482fffeafa..f0b51249c43b 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -45,7 +45,7 @@ public override IEnumerable GenerateIntermediateMarshallingStat InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("MemoryMarshal"), + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), GenericName( Identifier("Cast")) .WithTypeArgumentList( @@ -87,7 +87,7 @@ public override IEnumerable GenerateIntermediateUnmarshallingSt InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("MemoryMarshal"), + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), GenericName( Identifier("Cast")) .WithTypeArgumentList( diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs index b728c68aea34..c060626bd9b8 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -37,13 +37,14 @@ private string GetNativeSpanIdentifier(TypePositionInfo info, StubCodeContext co return context.GetIdentifiers(info).managed + "__nativeSpan"; } - private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(string nativeSpanIdentifier) - { + private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(TypePositionInfo info, StubCodeContext context) + { + string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); return LocalDeclarationStatement(VariableDeclaration( GenericName( Identifier(TypeNames.System_Span), TypeArgumentList( - SingletonSeparatedList(PredefinedType(Token(SyntaxKind.ByteKeyword)))) + SingletonSeparatedList(elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax())) ), SingletonSeparatedList( VariableDeclarator(Identifier(nativeSpanIdentifier)) @@ -51,7 +52,7 @@ private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(string nat InvocationExpression( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("MemoryMarshal"), + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), GenericName( Identifier("Cast")) .WithTypeArgumentList( @@ -60,8 +61,12 @@ private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(string nat new [] { PredefinedType(Token(SyntaxKind.ByteKeyword)), - elementMarshaller.AsNativeType(elementInfo) - })))))))))); + elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax() + }))))) + .AddArgumentListArguments( + Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(GetMarshallerIdentifier(info, context)), + IdentifierName("NativeValueStorage"))))))))); } private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) @@ -78,14 +83,24 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in ? $"{marshalerIdentifier}.ManagedValues" : nativeSpanIdentifier; + TypePositionInfo localElementInfo = elementInfo with { InstanceIdentifier = info.InstanceIdentifier, RefKind = info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind() }; + + StatementSyntax marshallingStatement = Block( + List(elementMarshaller.Generate( + localElementInfo, + elementSubContext))); + + if (elementMarshaller.AsNativeType(elementInfo) is PointerTypeSyntax) + { + PointerNativeTypeAssignmentRewriter rewriter = new(elementSubContext.GetIdentifiers(localElementInfo).native); + marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement); + } + // Iterate through the elements of the native collection to unmarshal them return Block( - GenerateNativeSpanDeclaration(GetNativeSpanIdentifier(info, context)), + GenerateNativeSpanDeclaration(info, context), MarshallerHelpers.GetForLoop(collectionIdentifierForLength, IndexerIdentifier) - .WithStatement(Block( - List(elementMarshaller.Generate( - elementInfo, - elementSubContext))))); + .WithStatement(marshallingStatement)); } public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) @@ -109,5 +124,31 @@ public override IEnumerable GenerateIntermediateUnmarshallingSt { yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); } + + /// + /// Rewrite assignment expressions to the native identifier to cast to IntPtr. + /// This handles the case where the native type of a non-blittable managed type is a pointer, + /// which are unsupported in generic type parameters. + /// + private class PointerNativeTypeAssignmentRewriter : CSharpSyntaxRewriter + { + private readonly string nativeIdentifier; + + public PointerNativeTypeAssignmentRewriter(string nativeIdentifier) + { + this.nativeIdentifier = nativeIdentifier; + } + + public override SyntaxNode VisitAssignmentExpression(AssignmentExpressionSyntax node) + { + if (node.Left.ToString() == nativeIdentifier) + { + return node.WithRight( + CastExpression(ParseTypeName("System.IntPtr"), node.Right)); + } + + return node; + } + } } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs index fce8c225d7d6..8abc59cb1912 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs @@ -10,7 +10,6 @@ namespace Microsoft.Interop { class CustomNativeTypeMarshaller : IMarshallingGenerator { - internal const string MarshalerLocalSuffix = "__marshaler"; private readonly TypeSyntax _nativeTypeSyntax; private readonly TypeSyntax _nativeLocalTypeSyntax; private readonly SupportedMarshallingMethods _marshallingMethods; @@ -40,8 +39,9 @@ public CustomNativeTypeMarshaller(GeneratedNativeMarshallingAttributeInfo marsha public string GetMarshallerIdentifier(TypePositionInfo info, StubCodeContext context) { - var (_, nativeIdentifier) = context.GetIdentifiers(info); - return _useValueProperty ? nativeIdentifier + MarshalerLocalSuffix : nativeIdentifier; + return _useValueProperty + ? MarshallerHelpers.GetMarshallerIdentifier(info, context) + : context.GetIdentifiers(info).native; } public TypeSyntax AsNativeType(TypePositionInfo info) @@ -166,9 +166,10 @@ public virtual IEnumerable Generate(TypePositionInfo info, Stub _nativeLocalTypeSyntax, IdentifierName(ManualTypeMarshallingHelper.StackBufferSizeFieldName))) }))))); - arguments.AddRange(GenerateAdditionalNativeTypeConstructorArguments(info, context)); } + arguments.AddRange(GenerateAdditionalNativeTypeConstructorArguments(info, context)); + // = new <_nativeLocalType>(); yield return ExpressionStatement( AssignmentExpression( diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs index 9d293fa338a4..e0529fe55b05 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallerHelpers.cs @@ -15,6 +15,8 @@ internal static class MarshallerHelpers public static readonly TypeSyntax InteropServicesMarshalType = ParseTypeName(TypeNames.System_Runtime_InteropServices_Marshal); + public static readonly TypeSyntax SystemIntPtrType = ParseTypeName("System.IntPtr"); + public static ForStatementSyntax GetForLoop(string collectionIdentifier, string indexerIdentifier) { // for(int = 0; < .Length; ++) @@ -79,11 +81,18 @@ public static TypeSyntax GetCompatibleGenericTypeParameterSyntax(this TypeSyntax if (spanElementTypeSyntax is PointerTypeSyntax) { // Pointers cannot be passed to generics, so use IntPtr for this case. - spanElementTypeSyntax = ParseTypeName("System.IntPtr"); + spanElementTypeSyntax = SystemIntPtrType; } return spanElementTypeSyntax; } + private const string MarshalerLocalSuffix = "__marshaler"; + public static string GetMarshallerIdentifier(TypePositionInfo info, StubCodeContext context) + { + var (_, nativeIdentifier) = context.GetIdentifiers(info); + return nativeIdentifier + MarshalerLocalSuffix; + } + public static class StringMarshaller { public static ExpressionSyntax AllocationExpression(CharEncoding encoding, string managedIdentifier) @@ -122,7 +131,7 @@ public static ExpressionSyntax FreeExpression(string nativeIdentifier) ArgumentList(SingletonSeparatedList( Argument( CastExpression( - ParseTypeName("System.IntPtr"), + SystemIntPtrType, IdentifierName(nativeIdentifier)))))); } } diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index e737f487393d..f296e9da69e1 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -38,7 +38,9 @@ public static string MarshalEx(AnalyzerConfigOptions options) } public const string System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata = "System.Runtime.InteropServices.GeneratedMarshalling.ArrayMarshaller`1"; - + + public const string System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata = "System.Runtime.InteropServices.GeneratedMarshalling.PtrArrayMarshaller`1"; + public const string System_Runtime_InteropServices_MemoryMarshal = "System.Runtime.InteropServices.MemoryMarshal"; public const string System_Runtime_InteropServices_SafeHandle = "System.Runtime.InteropServices.SafeHandle"; @@ -48,5 +50,9 @@ public static string MarshalEx(AnalyzerConfigOptions options) public const string System_Runtime_InteropServices_InAttribute = "System.Runtime.InteropServices.InAttribute"; public const string System_Runtime_CompilerServices_SkipLocalsInitAttribute = "System.Runtime.CompilerServices.SkipLocalsInitAttribute"; + + // TODO: Add configuration for using Internal.Runtime.CompilerServices.Unsafe to support + // running against System.Private.CoreLib + public const string System_Runtime_CompilerServices_Unsafe = "System.Runtime.CompilerServices.Unsafe"; } } diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 6c14a5ebd6da..3d37ca966bab 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -272,7 +272,16 @@ static MarshallingInfo CreateMarshalAsInfo( elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel++); } - INamedTypeSymbol? arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + INamedTypeSymbol? arrayMarshaller; + + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); + } + else + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + } if (arrayMarshaller is null) { @@ -283,7 +292,7 @@ static MarshallingInfo CreateMarshalAsInfo( return new NativeContiguousCollectionMarshallingInfo( NativeMarshallingType: arrayMarshaller, ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, - MarshallingMethods: SupportedMarshallingMethods.All, + MarshallingMethods: ~SupportedMarshallingMethods.Pinning, NativeTypePinnable : true, UseDefaultMarshalling: true, ElementCountInfo: arraySizeInfo, @@ -360,7 +369,16 @@ static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshalli if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) { - INamedTypeSymbol? arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + INamedTypeSymbol? arrayMarshaller; + + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); + } + else + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + } if (arrayMarshaller is null) { @@ -372,7 +390,7 @@ static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshalli marshallingInfo = new NativeContiguousCollectionMarshallingInfo( NativeMarshallingType: arrayMarshaller!, ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller!)?.Type, - MarshallingMethods: SupportedMarshallingMethods.All, + MarshallingMethods: ~SupportedMarshallingMethods.Pinning, NativeTypePinnable: true, UseDefaultMarshalling: true, ElementCountInfo: NoCountInfo.Instance, From 1dd69a07f0cef729db8278f49ff4dc9dce58d532 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 28 May 2021 12:08:30 -0700 Subject: [PATCH 06/30] Fix integration tests. Extend collection marshallers design with an additional constructor to specify native element size for the unmarshal-only case. --- .../Ancillary.Interop/ArrayMarshaller.cs | 20 ++++++++++++++---- ...onBlittableElementsMarshallingGenerator.cs | 14 ++++++++++++- ...onBlittableElementsMarshallingGenerator.cs | 21 +++++++++++++++++-- .../Marshalling/StringMarshaller.Ansi.cs | 2 +- .../Marshalling/StringMarshaller.Utf16.cs | 2 +- designs/SpanMarshallers.md | 2 ++ 6 files changed, 52 insertions(+), 9 deletions(-) diff --git a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs index 9d5cba5b591e..4710fb8337cb 100644 --- a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs +++ b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs @@ -10,6 +10,12 @@ public unsafe ref struct ArrayMarshaller private readonly int sizeOfNativeElement; private IntPtr allocatedMemory; + public ArrayMarshaller(int sizeOfNativeElement) + :this() + { + this.sizeOfNativeElement = sizeOfNativeElement; + } + public ArrayMarshaller(T[]? managed, int sizeOfNativeElement) { allocatedMemory = default; @@ -25,7 +31,7 @@ public ArrayMarshaller(T[]? managed, int sizeOfNativeElement) // Always allocate at least one byte when the array is zero-length. int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); - NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + NativeValueStorage = new Span((void*)allocatedMemory, spaceToAllocate); } public ArrayMarshaller(T[]? managed, Span stackSpace, int sizeOfNativeElement) @@ -48,7 +54,7 @@ public ArrayMarshaller(T[]? managed, Span stackSpace, int sizeOfNativeElem else { allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); - NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + NativeValueStorage = new Span((void*)allocatedMemory, spaceToAllocate); } } @@ -109,6 +115,12 @@ public void FreeNative() private readonly int sizeOfNativeElement; private IntPtr allocatedMemory; + public PtrArrayMarshaller(int sizeOfNativeElement) + : this() + { + this.sizeOfNativeElement = sizeOfNativeElement; + } + public PtrArrayMarshaller(T*[]? managed, int sizeOfNativeElement) { allocatedMemory = default; @@ -124,7 +136,7 @@ public PtrArrayMarshaller(T*[]? managed, int sizeOfNativeElement) // Always allocate at least one byte when the array is zero-length. int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); - NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + NativeValueStorage = new Span((void*)allocatedMemory, spaceToAllocate); } public PtrArrayMarshaller(T*[]? managed, Span stackSpace, int sizeOfNativeElement) @@ -147,7 +159,7 @@ public PtrArrayMarshaller(T*[]? managed, Span stackSpace, int sizeOfNative else { allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); - NativeValueStorage = new Span((void*)allocatedMemory, managed.Length); + NativeValueStorage = new Span((void*)allocatedMemory, spaceToAllocate); } } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs index f0b51249c43b..d6be98aee426 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -22,9 +22,14 @@ public ContiguousCollectionBlittableElementsMarshallingGenerator( this.numElementsExpression = numElementsExpression; } + private ExpressionSyntax GenerateSizeOfElementExpression() + { + return SizeOfExpression(elementType.AsTypeSyntax()); + } + public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { - yield return Argument(SizeOfExpression(elementType.AsTypeSyntax())); + yield return Argument(GenerateSizeOfElementExpression()); } public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) @@ -67,6 +72,13 @@ public override IEnumerable GenerateIntermediateMarshallingStat public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); + if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) + { + yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(marshalerIdentifier), + ImplicitObjectCreationExpression() + .AddArgumentListArguments(Argument(GenerateSizeOfElementExpression())))); + } yield return ExpressionStatement( InvocationExpression( MemberAccessExpression( diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs index c060626bd9b8..d2d09e143aaf 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -27,9 +27,14 @@ public ContiguousCollectionNonBlittableElementsMarshallingGenerator( this.numElementsExpression = numElementsExpression; } + private ExpressionSyntax GenerateSizeOfElementExpression() + { + return SizeOfExpression(elementMarshaller.AsNativeType(elementInfo)); + } + public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) { - yield return Argument(SizeOfExpression(elementMarshaller.AsNativeType(elementInfo))); + yield return Argument(GenerateSizeOfElementExpression()); } private string GetNativeSpanIdentifier(TypePositionInfo info, StubCodeContext context) @@ -83,7 +88,13 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in ? $"{marshalerIdentifier}.ManagedValues" : nativeSpanIdentifier; - TypePositionInfo localElementInfo = elementInfo with { InstanceIdentifier = info.InstanceIdentifier, RefKind = info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind() }; + TypePositionInfo localElementInfo = elementInfo with + { + InstanceIdentifier = info.InstanceIdentifier, + RefKind = info.IsByRef ? info.RefKind : info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind(), + ManagedIndex = info.ManagedIndex, + NativeIndex = info.NativeIndex + }; StatementSyntax marshallingStatement = Block( List(elementMarshaller.Generate( @@ -111,6 +122,12 @@ public override IEnumerable GenerateIntermediateMarshallingStat public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) { string marshalerIdentifier = GetMarshallerIdentifier(info, context); + if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) + { + yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(marshalerIdentifier), + ImplicitObjectCreationExpression().AddArgumentListArguments(Argument(GenerateSizeOfElementExpression())))); + } yield return ExpressionStatement( InvocationExpression( MemberAccessExpression( diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Ansi.cs b/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Ansi.cs index b12dcfcd0231..96b7c615dfc3 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Ansi.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Ansi.cs @@ -126,7 +126,7 @@ public override IEnumerable Generate(TypePositionInfo info, Stu BinaryExpression( SyntaxKind.EqualsExpression, IdentifierName(nativeIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)), + LiteralExpression(SyntaxKind.DefaultLiteralExpression)), LiteralExpression(SyntaxKind.NullLiteralExpression), ObjectCreationExpression( PredefinedType(Token(SyntaxKind.StringKeyword)), diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Utf16.cs b/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Utf16.cs index ffe989fddd2e..c5a50066f8f6 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Utf16.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/StringMarshaller.Utf16.cs @@ -112,7 +112,7 @@ public override IEnumerable Generate(TypePositionInfo info, Stu BinaryExpression( SyntaxKind.EqualsExpression, IdentifierName(nativeIdentifier), - LiteralExpression(SyntaxKind.NullLiteralExpression)), + LiteralExpression(SyntaxKind.DefaultLiteralExpression)), LiteralExpression(SyntaxKind.NullLiteralExpression), ObjectCreationExpression( PredefinedType(Token(SyntaxKind.StringKeyword)), diff --git a/designs/SpanMarshallers.md b/designs/SpanMarshallers.md index 4fa80be494fe..a1b3e74f672f 100644 --- a/designs/SpanMarshallers.md +++ b/designs/SpanMarshallers.md @@ -94,6 +94,8 @@ A generic collection marshaller would be required to have the following shape, i [GenericContiguousCollectionMarshaller] public struct GenericContiguousCollectionMarshallerImpl { + // this constructor is required if marshalling from native to managed is supported. + public GenericContiguousCollectionMarshallerImpl(int nativeSizeOfElement); // these constructors are required if marshalling from managed to native is supported. public GenericContiguousCollectionMarshallerImpl(GenericCollection collection, int nativeSizeOfElement); public GenericContiguousCollectionMarshallerImpl(GenericCollection collection, Span stackSpace, int nativeSizeOfElement); // optional From 8c8210534f29c3b72b4d04849494d68d95c15902 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 28 May 2021 13:36:53 -0700 Subject: [PATCH 07/30] Implement element cleanup. --- .../DllImportGenerator/Marshalling/ArrayMarshaller.cs | 5 +++++ ...CollectionNonBlittableElementsMarshallingGenerator.cs | 5 +++++ .../Marshalling/CustomNativeTypeMarshaller.cs | 9 +++++++++ .../Marshalling/SafeHandleMarshaller.cs | 3 +-- 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs index 62100a03d513..7649919fbfb1 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -149,6 +149,11 @@ public override IEnumerable GenerateIntermediateUnmarshallingSt return innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context); } + public override IEnumerable GenerateIntermediateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerCollectionMarshaller.GenerateIntermediateCleanupStatements(info, context); + } + public override bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) { return !(blittable && context.PinningSupported) && marshalKind.HasFlag(ByValueContentsMarshalKind.Out); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs index d2d09e143aaf..8165109a6595 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -142,6 +142,11 @@ public override IEnumerable GenerateIntermediateUnmarshallingSt yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); } + public override IEnumerable GenerateIntermediateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); + } + /// /// Rewrite assignment expressions to the native identifier to cast to IntPtr. /// This handles the case where the native type of a non-blittable managed type is a pointer, diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs index 8abc59cb1912..9296d709bcb5 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs @@ -256,6 +256,10 @@ public virtual IEnumerable Generate(TypePositionInfo info, Stub case StubCodeContext.Stage.Cleanup: if (info.RefKind != RefKind.Out && _hasFreeNative) { + foreach (var statement in GenerateIntermediateCleanupStatements(info, context)) + { + yield return statement; + } // .FreeNative(); yield return ExpressionStatement( InvocationExpression( @@ -290,6 +294,11 @@ public virtual IEnumerable GenerateIntermediateUnmarshallingSta return Array.Empty(); } + public virtual IEnumerable GenerateIntermediateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + public virtual bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) { if (info.IsManagedReturnPosition || info.IsByRef && info.RefKind != RefKind.In) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs index 7fa6513186c1..58e062aa57c2 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/SafeHandleMarshaller.cs @@ -10,7 +10,6 @@ namespace Microsoft.Interop { internal class SafeHandleMarshaller : IMarshallingGenerator { - private static readonly TypeSyntax NativeType = ParseTypeName("global::System.IntPtr"); private readonly AnalyzerConfigOptions options; public SafeHandleMarshaller(AnalyzerConfigOptions options) @@ -20,7 +19,7 @@ public SafeHandleMarshaller(AnalyzerConfigOptions options) public TypeSyntax AsNativeType(TypePositionInfo info) { - return NativeType; + return MarshallerHelpers.SystemIntPtrType; } public ParameterSyntax AsParameter(TypePositionInfo info) From 51a5e0dda5254c1d022289b56cfc0d1f0dc2fe57 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 28 May 2021 15:42:27 -0700 Subject: [PATCH 08/30] Add support for parsing count info and indirection info out of MarshalUsingAttribute. --- .../GeneratedMarshallingAttribute.cs | 24 ++- .../DllImportGenerator/DllImportStub.cs | 4 +- .../ManualTypeMarshallingHelper.cs | 23 ++- .../Marshalling/MarshallingGenerator.cs | 17 +- .../MarshallingAttributeInfo.cs | 6 +- .../DllImportGenerator/StubCodeGenerator.cs | 6 +- .../DllImportGenerator/TypeNames.cs | 2 + .../DllImportGenerator/TypePositionInfo.cs | 184 +++++++++++++++--- 8 files changed, 218 insertions(+), 48 deletions(-) diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs index 11e1b29639fb..29088afc44da 100644 --- a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs +++ b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs @@ -25,11 +25,33 @@ public NativeMarshallingAttribute(Type nativeType) [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.ReturnValue | AttributeTargets.Field)] public class MarshalUsingAttribute : Attribute { + public MarshalUsingAttribute() + { + CountElementName = null!; + } + public MarshalUsingAttribute(Type nativeType) + :this() { NativeType = nativeType; } - public Type NativeType { get; } + public Type? NativeType { get; } + + public string CountElementName { get; set; } + + public int ConstantElementCount { get; set; } + + public int ElementIndirectionLevel { get; set; } + + public const string ReturnsCountValue = "return-value"; + } + + [AttributeUsage(AttributeTargets.Struct | AttributeTargets.Class)] + public sealed class GenericContiguousCollectionMarshallerAttribute : Attribute + { + public GenericContiguousCollectionMarshallerAttribute() + { + } } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/DllImportStub.cs b/DllImportGenerator/DllImportGenerator/DllImportStub.cs index cff9d7f41d83..52bff59b84dd 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportStub.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportStub.cs @@ -162,7 +162,7 @@ public static DllImportStub Create( for (int i = 0; i < method.Parameters.Length; i++) { var param = method.Parameters[i]; - var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics, method.ContainingType); + var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics, method); typeInfo = typeInfo with { ManagedIndex = i, @@ -171,7 +171,7 @@ public static DllImportStub Create( paramsTypeInfo.Add(typeInfo); } - TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics, method.ContainingType); + TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics, method); retTypeInfo = retTypeInfo with { ManagedIndex = TypePositionInfo.ReturnIndex, diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index eb825f1c7a61..905e85ae6f9e 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -23,8 +23,18 @@ public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol manage && !m.IsStatic); } - public static bool IsManagedToNativeConstructor(IMethodSymbol ctor, ITypeSymbol managedType) + public static bool IsManagedToNativeConstructor( + IMethodSymbol ctor, + ITypeSymbol managedType, + ITypeSymbol int32, + bool isCollectionMarshaller) { + if (isCollectionMarshaller) + { + return ctor.Parameters.Length == 2 + && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) + && SymbolEqualityComparer.Default.Equals(int32, ctor.Parameters[1].Type); + } return ctor.Parameters.Length == 1 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type); } @@ -32,8 +42,17 @@ public static bool IsManagedToNativeConstructor(IMethodSymbol ctor, ITypeSymbol public static bool IsStackallocConstructor( IMethodSymbol ctor, ITypeSymbol managedType, - ITypeSymbol spanOfByte) + ITypeSymbol spanOfByte, + ITypeSymbol int32, + bool isCollectionMarshaller) { + if (isCollectionMarshaller) + { + return ctor.Parameters.Length == 3 + && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) + && SymbolEqualityComparer.Default.Equals(spanOfByte, ctor.Parameters[1].Type) + && SymbolEqualityComparer.Default.Equals(int32, ctor.Parameters[2].Type); + } return ctor.Parameters.Length == 2 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) && SymbolEqualityComparer.Default.Equals(spanOfByte, ctor.Parameters[1].Type); diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index 28588aa988dd..e31615c405de 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -369,9 +369,9 @@ private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(Type { SizeAndParamIndexInfo(int size, SizeAndParamIndexInfo.UnspecifiedData) => GetConstSizeExpression(size), ConstSizeCountInfo(int size) => GetConstSizeExpression(size), - SizeAndParamIndexInfo(SizeAndParamIndexInfo.UnspecifiedData, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParamIndex(paramIndex)), - SizeAndParamIndexInfo(int size, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, BinaryExpression(SyntaxKind.AddExpression, GetConstSizeExpression(size), GetExpressionForParamIndex(paramIndex))), - CountElementCountInfo(string elementName) => throw new NotImplementedException(), + SizeAndParamIndexInfo(SizeAndParamIndexInfo.UnspecifiedData, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParam(context.GetTypePositionInfoForManagedIndex(paramIndex))), + SizeAndParamIndexInfo(int size, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, BinaryExpression(SyntaxKind.AddExpression, GetConstSizeExpression(size), GetExpressionForParam(context.GetTypePositionInfoForManagedIndex(paramIndex)))), + CountElementCountInfo(TypePositionInfo elementInfo) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParam(elementInfo)), _ => throw new MarshallingNotSupportedException(info, context) { NotSupportedDetails = Resources.ArraySizeMustBeSpecified @@ -383,17 +383,16 @@ static LiteralExpressionSyntax GetConstSizeExpression(int size) return LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(size)); } - ExpressionSyntax GetExpressionForParamIndex(int index) + ExpressionSyntax GetExpressionForParam(TypePositionInfo? paramInfo) { - TypePositionInfo? paramIndexInfo = context.GetTypePositionInfoForManagedIndex(index); - if (paramIndexInfo is null) + if (paramInfo is null) { throw new MarshallingNotSupportedException(info, context) { NotSupportedDetails = Resources.ArraySizeParamIndexOutOfRange }; } - else if (!paramIndexInfo.ManagedType.IsIntegralType()) + else if (!paramInfo.ManagedType.IsIntegralType()) { throw new MarshallingNotSupportedException(info, context) { @@ -402,8 +401,8 @@ ExpressionSyntax GetExpressionForParamIndex(int index) } else { - var (managed, native) = context.GetIdentifiers(paramIndexInfo); - string identifier = Create(paramIndexInfo, context, options).UsesNativeIdentifier(paramIndexInfo, context) ? native : managed; + var (managed, native) = context.GetIdentifiers(paramInfo); + string identifier = Create(paramInfo, context, options).UsesNativeIdentifier(paramInfo, context) ? native : managed; return CastExpression( PredefinedType(Token(SyntaxKind.IntKeyword)), IdentifierName(identifier)); diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index d03e8d539aab..c782ad2ed60e 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -59,6 +59,7 @@ internal sealed record BlittableTypeAttributeInfo : MarshallingInfo; [Flags] internal enum SupportedMarshallingMethods { + None = 0, ManagedToNative = 0x1, NativeToManaged = 0x2, ManagedToNativeStackalloc = 0x4, @@ -77,7 +78,10 @@ private NoCountInfo() { } internal sealed record ConstSizeCountInfo(int Size) : CountInfo; - internal sealed record CountElementCountInfo(string parameterName) : CountInfo; + internal sealed record CountElementCountInfo(TypePositionInfo ElementInfo) : CountInfo + { + public const string ReturnValueElementName = "return-value"; + } internal sealed record SizeAndParamIndexInfo(int ConstSize, int ParamIndex) : CountInfo { diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index ea5995414f89..0c72508a3686 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -94,7 +94,7 @@ public StubCodeGenerator( public override (string managed, string native) GetIdentifiers(TypePositionInfo info) { - if (info.IsManagedReturnPosition && !info.IsNativeReturnPosition) + if (info.IsManagedReturnPosition) { return (ReturnIdentifier, ReturnNativeIdentifier); } @@ -102,10 +102,6 @@ public override (string managed, string native) GetIdentifiers(TypePositionInfo { return (InvokeReturnIdentifier, InvokeReturnIdentifier); } - else if (info.IsManagedReturnPosition && info.IsNativeReturnPosition) - { - return (ReturnIdentifier, ReturnNativeIdentifier); - } else { // If the info isn't in either the managed or native return position, diff --git a/DllImportGenerator/DllImportGenerator/TypeNames.cs b/DllImportGenerator/DllImportGenerator/TypeNames.cs index f296e9da69e1..ad75ffadb8c5 100644 --- a/DllImportGenerator/DllImportGenerator/TypeNames.cs +++ b/DllImportGenerator/DllImportGenerator/TypeNames.cs @@ -17,6 +17,8 @@ static class TypeNames public const string MarshalUsingAttribute = "System.Runtime.InteropServices.MarshalUsingAttribute"; + public const string GenericContiguousCollectionMarshallerAttribute = "System.Runtime.InteropServices.GenericContiguousCollectionMarshallerAttribute"; + public const string LCIDConversionAttribute = "System.Runtime.InteropServices.LCIDConversionAttribute"; public const string System_Span_Metadata = "System.Span`1"; diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 3d37ca966bab..aee676e53305 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; @@ -82,9 +83,9 @@ private TypePositionInfo() public MarshallingInfo MarshallingAttributeInfo { get; init; } - public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) + public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, IMethodSymbol methodSymbol) { - var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics, scopeSymbol); + var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics, methodSymbol); var typeInfo = new TypePositionInfo() { ManagedType = paramSymbol.Type, @@ -98,9 +99,9 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, return typeInfo; } - public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol) + public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, ISymbol symbol) { - var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics, scopeSymbol); + var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics, symbol); var typeInfo = new TypePositionInfo() { ManagedType = type, @@ -127,21 +128,28 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo m return typeInfo; } - private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, int indirectionLevel = 0) + private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, ISymbol contextSymbol, int indirectionLevel = 0) { + CountInfo parsedCountInfo = NoCountInfo.Instance; // Look at attributes passed in - usage specific. foreach (var attrData in attributes) { INamedTypeSymbol attributeClass = attrData.AttributeClass!; - if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) + if (indirectionLevel == 0 + && SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) { // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel); + return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, indirectionLevel); } - else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass)) + else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass) + && AttributeAppliesToCurrentIndirectionLevel(attrData, indirectionLevel)) { - return CreateNativeMarshallingInfo(type, compilation, attrData, useDefaultMarshalling: false, indirectionLevel); + parsedCountInfo = CreateCountInfo(attrData); + if (attrData.ConstructorArguments.Length != 0) + { + return CreateNativeMarshallingInfo(type, compilation, attrData, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo); + } } } @@ -162,7 +170,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.NativeMarshallingAttribute), attributeClass)) { - return CreateNativeMarshallingInfo(type, compilation, attrData, useDefaultMarshalling: true, indirectionLevel); + return CreateNativeMarshallingInfo(type, compilation, attrData, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.GeneratedMarshallingAttribute), attributeClass)) { @@ -172,7 +180,14 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< // If the type doesn't have custom attributes that dictate marshalling, // then consider the type itself. - if (TryCreateTypeBasedMarshallingInfo(type, defaultInfo, compilation, diagnostics, scopeSymbol, out MarshallingInfo infoMaybe, indirectionLevel)) + if (TryCreateTypeBasedMarshallingInfo( + type, + defaultInfo, + compilation, + diagnostics, + contextSymbol.ContainingType, + indirectionLevel, + out MarshallingInfo infoMaybe)) { return infoMaybe; } @@ -188,13 +203,11 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< return NoMarshallingInfo.Instance; - static MarshallingInfo CreateMarshalAsInfo( + MarshallingInfo CreateMarshalAsInfo( ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, Compilation compilation, - GeneratorDiagnostics diagnostics, - INamedTypeSymbol scopeSymbol, int indirectionLevel) { object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; @@ -269,7 +282,7 @@ static MarshallingInfo CreateMarshalAsInfo( } else { - elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel++); + elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, contextSymbol, indirectionLevel++); } INamedTypeSymbol? arrayMarshaller; @@ -300,38 +313,97 @@ static MarshallingInfo CreateMarshalAsInfo( ElementMarshallingInfo: elementMarshallingInfo); } - static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, AttributeData attrData, bool useDefaultMarshalling, int indirectionLevel) + CountInfo CreateCountInfo(AttributeData marshalUsingData) + { + int? constSize = null; + string? elementName = null; + foreach (var arg in marshalUsingData.NamedArguments) + { + if (arg.Key == "ConstantElementCount") + { + constSize = (int)arg.Value.Value!; + } + else if (arg.Key == "CountElementName") + { + if (arg.Value.Value is null) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", "null"); + return NoCountInfo.Instance; + } + elementName = (string)arg.Value.Value!; + } + } + + if (constSize is not null && elementName is not null) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "ConstantElementCount and CountElementName combined"); + } + else if (constSize is not null) + { + return new ConstSizeCountInfo(constSize.Value); + } + else if (elementName is not null) + { + TypePositionInfo? elementInfo = CreateForElementName(compilation, diagnostics, defaultInfo, contextSymbol, elementName); + if (elementInfo is null) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", elementName); + return NoCountInfo.Instance; + } + return new CountElementCountInfo(elementInfo); + } + + return NoCountInfo.Instance; + } + + static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo) { + SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; + + if (!isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) + { + methods |= SupportedMarshallingMethods.Pinning; + } + ITypeSymbol spanOfByte = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(compilation.GetSpecialType(SpecialType.System_Byte)); + ITypeSymbol int32 = compilation.GetSpecialType(SpecialType.System_Int32); + INamedTypeSymbol nativeType = (INamedTypeSymbol)attrData.ConstructorArguments[0].Value!; - SupportedMarshallingMethods methods = 0; + + ITypeSymbol contiguousCollectionMarshalerAttribute = compilation.GetTypeByMetadataName(TypeNames.GenericContiguousCollectionMarshallerAttribute)!; + + bool isContiguousCollectionMarshaller = nativeType.GetAttributes().Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, contiguousCollectionMarshalerAttribute)); IPropertySymbol? valueProperty = ManualTypeMarshallingHelper.FindValueProperty(nativeType); + + bool hasInt32Constructor = false; foreach (var ctor in nativeType.Constructors) { - if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type) + if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, int32, isCollectionMarshaller: true) && (valueProperty is null or { GetMethod: not null })) { methods |= SupportedMarshallingMethods.ManagedToNative; } - else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte) + else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, int32, isCollectionMarshaller: true) && (valueProperty is null or { GetMethod: not null })) { methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc; } + else if (ctor.Parameters.Length == 1 && SymbolEqualityComparer.Default.Equals(ctor.Parameters[0], int32)) + { + hasInt32Constructor = true; + } } - if (ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type) + // The constructor that takes only the native element size is required for collection marshallers + // in the native-to-managed scenario. + if ((!isContiguousCollectionMarshaller || hasInt32Constructor) + && ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type) && (valueProperty is null or { SetMethod: not null })) { methods |= SupportedMarshallingMethods.NativeToManaged; } - if (useDefaultMarshalling && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) - { - methods |= SupportedMarshallingMethods.Pinning; - } - - if (methods == 0) + if (methods == SupportedMarshallingMethods.None) { // TODO: Diagnostic since no marshalling methods are supported. } @@ -341,10 +413,17 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty valueProperty?.Type, methods, NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, - UseDefaultMarshalling: useDefaultMarshalling); + UseDefaultMarshalling: !isMarshalUsingAttribute); } - static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, out MarshallingInfo marshallingInfo, int indirectionLevel) + static bool TryCreateTypeBasedMarshallingInfo( + ITypeSymbol type, + DefaultMarshallingInfo defaultInfo, + Compilation compilation, + GeneratorDiagnostics diagnostics, + INamedTypeSymbol scopeSymbol, + int indirectionLevel, + out MarshallingInfo marshallingInfo) { var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); if (conversion.Exists @@ -413,6 +492,55 @@ static bool TryCreateTypeBasedMarshallingInfo(ITypeSymbol type, DefaultMarshalli } } + private static TypePositionInfo? CreateForElementName(Compilation compilation, GeneratorDiagnostics diagnostics, DefaultMarshallingInfo defaultInfo, ISymbol context, string elementName) + { + if (context is IMethodSymbol method) + { + if (elementName == CountElementCountInfo.ReturnValueElementName) + { + return CreateForType( + method.ReturnType, + method.GetReturnTypeAttributes(), + defaultInfo, + compilation, + diagnostics, + method) with + { + ManagedIndex = ReturnIndex + }; + } + + foreach (var param in method.Parameters) + { + if (param.Name == elementName) + { + return CreateForParameter(param, defaultInfo, compilation, diagnostics, method); + } + } + } + else if (context is INamedTypeSymbol _) + { + // TODO: Handle when we create a struct marshalling generator + // Do we want to support CountElementName pointing to only fields, or properties as well? + // If only fields, how do we handle properties with generated backing fields? + } + + return null; + } + + private static bool AttributeAppliesToCurrentIndirectionLevel(AttributeData attrData, int indirectionLevel) + { + int elementIndirectionLevel = 0; + foreach (var arg in attrData.NamedArguments) + { + if (arg.Key == "ElementIndirectionLevel") + { + elementIndirectionLevel = (int)arg.Value.Value!; + } + } + return elementIndirectionLevel == indirectionLevel; + } + private static ByValueContentsMarshalKind GetByValueContentsMarshalKind(IEnumerable attributes, Compilation compilation) { INamedTypeSymbol outAttributeType = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_OutAttribute)!; From cbf11f0582d5bece0ef3b725e0c91567d589e493 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 28 May 2021 16:04:49 -0700 Subject: [PATCH 09/30] Handle open generics in NativeMarshallingAttribute. --- .../DllImportGenerator/TypePositionInfo.cs | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index aee676e53305..5acdd6529c4b 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -148,7 +148,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< parsedCountInfo = CreateCountInfo(attrData); if (attrData.ConstructorArguments.Length != 0) { - return CreateNativeMarshallingInfo(type, compilation, attrData, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo); + return CreateNativeMarshallingInfo(type, compilation, diagnostics, attrData, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo); } } } @@ -170,7 +170,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.NativeMarshallingAttribute), attributeClass)) { - return CreateNativeMarshallingInfo(type, compilation, attrData, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo); + return CreateNativeMarshallingInfo(type, compilation, diagnostics, attrData, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.GeneratedMarshallingAttribute), attributeClass)) { @@ -356,7 +356,7 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) return NoCountInfo.Instance; } - static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo) + static MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, GeneratorDiagnostics diagnostics, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo) { SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; @@ -370,6 +370,32 @@ static NativeMarshallingAttributeInfo CreateNativeMarshallingInfo(ITypeSymbol ty INamedTypeSymbol nativeType = (INamedTypeSymbol)attrData.ConstructorArguments[0].Value!; + if (nativeType.IsUnboundGenericType) + { + if (isMarshalUsingAttribute) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + else if (type is INamedTypeSymbol namedType) + { + if (namedType.Arity != nativeType.Arity) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + else + { + nativeType = nativeType.Construct(namedType.TypeParameters.ToArray()); + } + } + else + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + } + ITypeSymbol contiguousCollectionMarshalerAttribute = compilation.GetTypeByMetadataName(TypeNames.GenericContiguousCollectionMarshallerAttribute)!; bool isContiguousCollectionMarshaller = nativeType.GetAttributes().Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, contiguousCollectionMarshalerAttribute)); From c147791fc75f383fd4bd568831033bc150c7e133 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 10:33:18 -0700 Subject: [PATCH 10/30] Implement element type lookup. --- .../GeneratedMarshallingAttribute.cs | 2 +- .../ManualTypeMarshallingHelper.cs | 25 +++++++++++ .../DllImportGenerator/TypePositionInfo.cs | 42 +++++++++++++++---- 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs index 29088afc44da..6939ffcf568b 100644 --- a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs +++ b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs @@ -22,7 +22,7 @@ public NativeMarshallingAttribute(Type nativeType) public Type NativeType { get; } } - [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.ReturnValue | AttributeTargets.Field)] + [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.ReturnValue | AttributeTargets.Field, AllowMultiple = true)] public class MarshalUsingAttribute : Attribute { public MarshalUsingAttribute() diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index 905e85ae6f9e..31d5f143142a 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -1,4 +1,5 @@ +using System; using System.Linq; using Microsoft.CodeAnalysis; @@ -11,6 +12,7 @@ static class ManualTypeMarshallingHelper public const string StackBufferSizeFieldName = "StackBufferSize"; public const string ToManagedMethodName = "ToManaged"; public const string FreeNativeMethodName = "FreeNative"; + public const string ManagedValuesPropertyName = "ManagedValues"; public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol managedType) { @@ -84,5 +86,28 @@ public static bool HasFreeNativeMethod(ITypeSymbol type) .Any(m => m is { Parameters: { Length: 0 } } and ({ ReturnType: { SpecialType: SpecialType.System_Void } })); } + + public static IPropertySymbol? FindManagedValuesProperty(ITypeSymbol type) + { + return type + .GetMembers(ManagedValuesPropertyName) + .OfType() + .FirstOrDefault(p => !p.IsStatic); + } + + public static bool TryGetElementTypeFromContiguousCollectionMarshaller(ITypeSymbol type, out ITypeSymbol elementType) + { + IPropertySymbol? managedValuesProperty = FindManagedValuesProperty(type); + + if (managedValuesProperty is null) + { + elementType = null!; + return false; + } + + elementType = ((INamedTypeSymbol)managedValuesProperty.Type).TypeArguments[0]; + return true; + } + } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 5acdd6529c4b..49889ac01c69 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -140,15 +140,20 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< && SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) { // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation, indirectionLevel); + return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass) && AttributeAppliesToCurrentIndirectionLevel(attrData, indirectionLevel)) { + if (parsedCountInfo != NoCountInfo.Instance) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Duplicate Count Info"); + return NoMarshallingInfo.Instance; + } parsedCountInfo = CreateCountInfo(attrData); if (attrData.ConstructorArguments.Length != 0) { - return CreateNativeMarshallingInfo(type, compilation, diagnostics, attrData, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo); + return CreateNativeMarshallingInfo(type, attrData, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo); } } } @@ -170,7 +175,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.NativeMarshallingAttribute), attributeClass)) { - return CreateNativeMarshallingInfo(type, compilation, diagnostics, attrData, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo); + return CreateNativeMarshallingInfo(type, attrData, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo); } else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.GeneratedMarshallingAttribute), attributeClass)) { @@ -207,8 +212,7 @@ MarshallingInfo CreateMarshalAsInfo( ITypeSymbol type, AttributeData attrData, DefaultMarshallingInfo defaultInfo, - Compilation compilation, - int indirectionLevel) + Compilation compilation) { object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; UnmanagedType unmanagedType = unmanagedTypeObj is short @@ -282,7 +286,8 @@ MarshallingInfo CreateMarshalAsInfo( } else { - elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, contextSymbol, indirectionLevel++); + // Indirection level does not matter since we don't pass down attributes to be inspected. + elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, contextSymbol, 0); } INamedTypeSymbol? arrayMarshaller; @@ -356,7 +361,7 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) return NoCountInfo.Instance; } - static MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation compilation, GeneratorDiagnostics diagnostics, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo) + MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo) { SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; @@ -434,6 +439,25 @@ static MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, Compilation // TODO: Diagnostic since no marshalling methods are supported. } + if (isContiguousCollectionMarshaller) + { + if (!ManualTypeMarshallingHelper.TryGetElementTypeFromContiguousCollectionMarshaller(nativeType, out ITypeSymbol elementType)) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + return new NativeContiguousCollectionMarshallingInfo( + nativeType, + valueProperty?.Type, + methods, + NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, + UseDefaultMarshalling: !isMarshalUsingAttribute, + parsedCountInfo, + elementType, + GetMarshallingInfo(elementType, attributes, defaultInfo, compilation, diagnostics, contextSymbol, indirectionLevel + 1)); + } + return new NativeMarshallingAttributeInfo( nativeType, valueProperty?.Type, @@ -493,8 +517,8 @@ static bool TryCreateTypeBasedMarshallingInfo( } marshallingInfo = new NativeContiguousCollectionMarshallingInfo( - NativeMarshallingType: arrayMarshaller!, - ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller!)?.Type, + NativeMarshallingType: arrayMarshaller, + ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, MarshallingMethods: ~SupportedMarshallingMethods.Pinning, NativeTypePinnable: true, UseDefaultMarshalling: true, From fd837d82dd4bad1160743b338c222bcdcd0c91d5 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 11:22:41 -0700 Subject: [PATCH 11/30] Add positive unit tests for collection marshalling. --- .../CodeSnippets.cs | 98 ++++++++++++-- .../CompileFails.cs | 6 +- .../DllImportGenerator.UnitTests/Compiles.cs | 126 +++++++++++++----- .../DllImportGenerator/TypePositionInfo.cs | 4 +- 4 files changed, 186 insertions(+), 48 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs index b89acc1f63db..83fd60aeeb3c 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs @@ -512,7 +512,7 @@ struct MyStruct private short s; }"; - public static string ArrayParametersAndModifiers(string elementType) => $@" + public static string MarshalAsArrayParametersAndModifiers(string elementType) => $@" using System.Runtime.InteropServices; partial class Test {{ @@ -528,9 +528,9 @@ out int pOutSize ); }}"; - public static string ArrayParametersAndModifiers() => ArrayParametersAndModifiers(typeof(T).ToString()); + public static string MarshalAsArrayParametersAndModifiers() => MarshalAsArrayParametersAndModifiers(typeof(T).ToString()); - public static string ArrayParameterWithSizeParam(string sizeParamType, bool isByRef) => $@" + public static string MarshalAsArrayParameterWithSizeParam(string sizeParamType, bool isByRef) => $@" using System.Runtime.InteropServices; partial class Test {{ @@ -541,10 +541,10 @@ public static partial void Method( ); }}"; - public static string ArrayParameterWithSizeParam(bool isByRef) => ArrayParameterWithSizeParam(typeof(T).ToString(), isByRef); + public static string MarshalAsArrayParameterWithSizeParam(bool isByRef) => MarshalAsArrayParameterWithSizeParam(typeof(T).ToString(), isByRef); - public static string ArrayParameterWithNestedMarshalInfo(string elementType, UnmanagedType nestedMarshalInfo) => $@" + public static string MarshalAsArrayParameterWithNestedMarshalInfo(string elementType, UnmanagedType nestedMarshalInfo) => $@" using System.Runtime.InteropServices; partial class Test {{ @@ -554,7 +554,7 @@ public static partial void Method( ); }}"; - public static string ArrayParameterWithNestedMarshalInfo(UnmanagedType nestedMarshalType) => ArrayParameterWithNestedMarshalInfo(typeof(T).ToString(), nestedMarshalType); + public static string MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType nestedMarshalType) => MarshalAsArrayParameterWithNestedMarshalInfo(typeof(T).ToString(), nestedMarshalType); public static string ArrayPreserveSigFalse(string elementType) => $@" using System.Runtime.InteropServices; @@ -915,7 +915,7 @@ struct Native } "; - public static string ArrayMarshallingWithCustomStructElementWithValueProperty => ArrayParametersAndModifiers("IntStructWrapper") + @" + public static string ArrayMarshallingWithCustomStructElementWithValueProperty => MarshalAsArrayParametersAndModifiers("IntStructWrapper") + @" [NativeMarshalling(typeof(IntStructWrapperNative))] public struct IntStructWrapper { @@ -935,7 +935,7 @@ public IntStructWrapperNative(IntStructWrapper managed) } "; - public static string ArrayMarshallingWithCustomStructElement => ArrayParametersAndModifiers("IntStructWrapper") + @" + public static string ArrayMarshallingWithCustomStructElement => MarshalAsArrayParametersAndModifiers("IntStructWrapper") + @" [NativeMarshalling(typeof(IntStructWrapperNative))] public struct IntStructWrapper { @@ -1083,5 +1083,87 @@ struct RecursiveStruct2 RecursiveStruct1 s; int i; }"; + + public static string CollectionByValue(string elementType) => BasicParameterByValue($"TestCollection<{elementType}>") + @" +[NativeMarshalling(typeof(Marshaller<>))] +class TestCollection {} + +[GenericCollectionMarshaller] +struct Marshaller +{ + public Marshaller(TestCollection managed, int nativeElementSize) {} + public System.Span ManagedValues { get; } + public System.Span NativeValueStorage { get; } + public IntPtr Value { get; } +} +"; + + public static string CollectionByValue() => CollectionByValue(typeof(T).ToString()); + + public static string MarshalUsingCollectionCountInfoParametersAndModifiers(string collectionType) => $@" +using System.Runtime.InteropServices; +partial class Test +{{ + [GeneratedDllImport(""DoesNotExist"")] + [return:MarshalUsing(ConstantElementCount=10)] + public static partial {collectionType} Method( + {collectionType} p, + in {collectionType} pIn, + int pRefSize, + [MarshalUsing(CountElementName = nameof(pRefSize))] ref {collectionType} pRef, + [MarshalUsing(CountElementName = nameof(pOutSize))] out {collectionType} pOut, + out int pOutSize + ); +}}"; + + public static string MarshalUsingCollectionCountInfoParametersAndModifiers() => MarshalUsingCollectionCountInfoParametersAndModifiers(typeof(T).ToString()); + + public static string CustomCollectionDefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + @" + +[NativeMarshalling(typeof(Marshaller<>))] +class TestCollection {} + +[GenericCollectionMarshaller] +struct Marshaller +{ + public Marshaller(TestCollection managed, int nativeElementSize) {} + public System.Span ManagedValues { get; } + public System.Span NativeValueStorage { get; } + public IntPtr Value { get; set; } + public TestCollection ToManaged() => throw null; +}"; + + public static string CustomCollectionDefaultMarshallerParametersAndModifiers() => CustomCollectionDefaultMarshallerParametersAndModifiers(typeof(T).ToString()); + + public static string MarshalUsingCollectionParametersAndModifiers(string collectionType, string marshallerType) => $@" +using System.Runtime.InteropServices; +partial class Test +{{ + [GeneratedDllImport(""DoesNotExist"")] + [return:MarshalUsing(typeof({marshallerType}), ConstantElementCount=10)] + public static partial {collectionType} Method( + [MarshalUsing(typeof({marshallerType})] {collectionType} p, + [MarshalUsing(typeof({marshallerType})] in {collectionType} pIn, + int pRefSize, + [MarshalUsing(typeof({marshallerType}), CountElementName = nameof(pRefSize))] ref {collectionType} pRef, + [MarshalUsing(typeof({marshallerType}), CountElementName = nameof(pOutSize))] out {collectionType} pOut, + out int pOutSize + ); +}}"; + + public static string CustomCollectionCustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + @" +class TestCollection {} + +[GenericCollectionMarshaller] +struct Marshaller +{ + public Marshaller(TestCollection managed, int nativeElementSize) {} + public System.Span ManagedValues { get; } + public System.Span NativeValueStorage { get; } + public IntPtr Value { get; set; } + public TestCollection ToManaged() => throw null; +}"; + + public static string CustomCollectionCustomMarshallerParametersAndModifiers() => CustomCollectionCustomMarshallerParametersAndModifiers(typeof(T).ToString()); } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs index 19736372a253..a0b02ce440d5 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs @@ -79,9 +79,9 @@ public static IEnumerable CodeSnippetsToCompile() yield return new object[] { CodeSnippets.BasicParametersAndModifiers(), 3, 0 }; // Array with non-integer size param - yield return new object[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false), 1, 0 }; - yield return new object[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false), 1, 0 }; - yield return new object[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + yield return new object[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + yield return new object[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + yield return new object[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false), 1, 0 }; // Custom type marshalling with invalid members yield return new object[] { CodeSnippets.CustomStructMarshallingByRefValueProperty, 3, 0 }; diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs index a785ed75ac65..4f78d42a75c6 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs @@ -37,39 +37,39 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.BasicParametersAndModifiers() }; // Arrays - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParametersAndModifiers() }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: false) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; - yield return new[] { CodeSnippets.ArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: true) }; // CharSet yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet(CharSet.Unicode) }; @@ -89,9 +89,9 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.LPTStr) }; yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.LPUTF8Str) }; yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.LPStr) }; - yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo(UnmanagedType.LPWStr) }; - yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo(UnmanagedType.LPUTF8Str) }; - yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo(UnmanagedType.LPStr) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.LPWStr) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.LPUTF8Str) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.LPStr) }; // [In, Out] attributes // By value non-blittable array @@ -203,6 +203,60 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.ImplicitlyBlittableStructParametersAndModifiers("internal") }; yield return new[] { CodeSnippets.ImplicitlyBlittableGenericTypeParametersAndModifiers() }; yield return new[] { CodeSnippets.ImplicitlyBlittableGenericTypeParametersAndModifiers("internal") }; + + // Custom collection marshalling + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.CollectionByValue() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.MarshalUsingCollectionCountInfoParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionDefaultMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 49889ac01c69..61daadfca4d7 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -191,6 +191,7 @@ private static MarshallingInfo GetMarshallingInfo(ITypeSymbol type, IEnumerable< compilation, diagnostics, contextSymbol.ContainingType, + parsedCountInfo, indirectionLevel, out MarshallingInfo infoMaybe)) { @@ -472,6 +473,7 @@ static bool TryCreateTypeBasedMarshallingInfo( Compilation compilation, GeneratorDiagnostics diagnostics, INamedTypeSymbol scopeSymbol, + CountInfo parsedCountInfo, int indirectionLevel, out MarshallingInfo marshallingInfo) { @@ -522,7 +524,7 @@ static bool TryCreateTypeBasedMarshallingInfo( MarshallingMethods: ~SupportedMarshallingMethods.Pinning, NativeTypePinnable: true, UseDefaultMarshalling: true, - ElementCountInfo: NoCountInfo.Instance, + ElementCountInfo: parsedCountInfo, ElementType: elementType, ElementMarshallingInfo: GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel + 1)); return true; From c0ae81207e025ccc3d3c235b12ecd932a55f14a7 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 14:39:17 -0700 Subject: [PATCH 12/30] Get tests passing and update resource string. --- .../CodeSnippets.cs | 113 +++++++++++++----- .../CompileFails.cs | 15 ++- .../DllImportGenerator.UnitTests/Compiles.cs | 1 + .../ManualTypeMarshallingAnalyzer.cs | 4 +- .../ManualTypeMarshallingHelper.cs | 21 +++- .../DllImportGenerator/Resources.Designer.cs | 2 +- .../DllImportGenerator/Resources.resx | 2 +- .../DllImportGenerator/TypePositionInfo.cs | 15 +-- 8 files changed, 127 insertions(+), 46 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs index 83fd60aeeb3c..8b4a5cb488ad 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs @@ -1088,13 +1088,13 @@ public static string CollectionByValue(string elementType) => BasicParameterByVa [NativeMarshalling(typeof(Marshaller<>))] class TestCollection {} -[GenericCollectionMarshaller] -struct Marshaller +[GenericContiguousCollectionMarshaller] +ref struct Marshaller { - public Marshaller(TestCollection managed, int nativeElementSize) {} + public Marshaller(TestCollection managed, int nativeElementSize) : this() {} public System.Span ManagedValues { get; } public System.Span NativeValueStorage { get; } - public IntPtr Value { get; } + public System.IntPtr Value { get; } } "; @@ -1110,28 +1110,33 @@ partial class Test {collectionType} p, in {collectionType} pIn, int pRefSize, - [MarshalUsing(CountElementName = nameof(pRefSize))] ref {collectionType} pRef, - [MarshalUsing(CountElementName = nameof(pOutSize))] out {collectionType} pOut, + [MarshalUsing(CountElementName = ""pRefSize"")] ref {collectionType} pRef, + [MarshalUsing(CountElementName = ""pOutSize"")] out {collectionType} pOut, out int pOutSize ); }}"; + + public static string CustomCollectionWithMarshaller(bool enableDefaultMarshalling) + { + string nativeMarshallingAttribute = enableDefaultMarshalling ? "[NativeMarshalling(typeof(Marshaller<>))]" : string.Empty; + return nativeMarshallingAttribute + @"class TestCollection {} - public static string MarshalUsingCollectionCountInfoParametersAndModifiers() => MarshalUsingCollectionCountInfoParametersAndModifiers(typeof(T).ToString()); - - public static string CustomCollectionDefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + @" - -[NativeMarshalling(typeof(Marshaller<>))] -class TestCollection {} - -[GenericCollectionMarshaller] -struct Marshaller +[GenericContiguousCollectionMarshaller] +ref struct Marshaller { - public Marshaller(TestCollection managed, int nativeElementSize) {} + public Marshaller(int nativeElementSize) : this() {} + public Marshaller(TestCollection managed, int nativeElementSize) : this() {} public System.Span ManagedValues { get; } public System.Span NativeValueStorage { get; } - public IntPtr Value { get; set; } + public System.IntPtr Value { get; set; } + public void SetUnmarshalledCollectionLength(int length) {} public TestCollection ToManaged() => throw null; }"; + } + + public static string MarshalUsingCollectionCountInfoParametersAndModifiers() => MarshalUsingCollectionCountInfoParametersAndModifiers(typeof(T).ToString()); + + public static string CustomCollectionDefaultMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionCountInfoParametersAndModifiers($"TestCollection<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); public static string CustomCollectionDefaultMarshallerParametersAndModifiers() => CustomCollectionDefaultMarshallerParametersAndModifiers(typeof(T).ToString()); @@ -1142,28 +1147,80 @@ partial class Test [GeneratedDllImport(""DoesNotExist"")] [return:MarshalUsing(typeof({marshallerType}), ConstantElementCount=10)] public static partial {collectionType} Method( - [MarshalUsing(typeof({marshallerType})] {collectionType} p, - [MarshalUsing(typeof({marshallerType})] in {collectionType} pIn, + [MarshalUsing(typeof({marshallerType}))] {collectionType} p, + [MarshalUsing(typeof({marshallerType}))] in {collectionType} pIn, int pRefSize, - [MarshalUsing(typeof({marshallerType}), CountElementName = nameof(pRefSize))] ref {collectionType} pRef, - [MarshalUsing(typeof({marshallerType}), CountElementName = nameof(pOutSize))] out {collectionType} pOut, + [MarshalUsing(typeof({marshallerType}), CountElementName = ""pRefSize"")] ref {collectionType} pRef, + [MarshalUsing(typeof({marshallerType}), CountElementName = ""pOutSize"")] out {collectionType} pOut, out int pOutSize ); }}"; - public static string CustomCollectionCustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + @" + public static string CustomCollectionCustomMarshallerParametersAndModifiers(string elementType) => MarshalUsingCollectionParametersAndModifiers($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: false); + + public static string CustomCollectionCustomMarshallerParametersAndModifiers() => CustomCollectionCustomMarshallerParametersAndModifiers(typeof(T).ToString()); + + public static string MarshalUsingCollectionReturnValueLength(string collectionType, string marshallerType) => $@" +using System.Runtime.InteropServices; +partial class Test +{{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial int Method( + [MarshalUsing(typeof({marshallerType}), CountElementName = MarshalUsingAttribute.ReturnsCountValue)] out {collectionType} pOut + ); +}}"; + + public static string CustomCollectionCustomMarshallerReturnValueLength(string elementType) => MarshalUsingCollectionReturnValueLength($"TestCollection<{elementType}>", $"Marshaller<{elementType}>") + CustomCollectionWithMarshaller(enableDefaultMarshalling: false); + + public static string CustomCollectionCustomMarshallerReturnValueLength() => CustomCollectionCustomMarshallerReturnValueLength(typeof(T).ToString()); + + public static string MarshalUsingArrayParameterWithSizeParam(string sizeParamType, bool isByRef) => $@" +using System.Runtime.InteropServices; +partial class Test +{{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + {(isByRef ? "ref" : "")} {sizeParamType} pRefSize, + [MarshalUsing(CountElementName = ""pRefSize"")] ref int[] pRef + ); +}}"; + + public static string MarshalUsingArrayParameterWithSizeParam(bool isByRef) => MarshalUsingArrayParameterWithSizeParam(typeof(T).ToString(), isByRef); + + public static string MarshalUsingCollectionWithConstantAndElementCount = $@" +using System.Runtime.InteropServices; +partial class Test +{{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + int pRefSize, + [MarshalUsing(ConstantElementCount = 10, CountElementName = ""pRefSize"")] ref int[] pRef + ); +}}"; + + public static string MarshalUsingCollectionWithNullElementName = $@" +using System.Runtime.InteropServices; +partial class Test +{{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + int pRefSize, + [MarshalUsing(CountElementName = null)] ref int[] pRef + ); +}}"; + + public static string GenericCollectionMarshallingArityMismatch = BasicParameterByValue("TestCollection") + @" +[NativeMarshalling(typeof(Marshaller<,>))] class TestCollection {} -[GenericCollectionMarshaller] -struct Marshaller +[GenericContiguousCollectionMarshaller] +ref struct Marshaller { - public Marshaller(TestCollection managed, int nativeElementSize) {} + public Marshaller(TestCollection managed, int nativeElementSize) : this() {} public System.Span ManagedValues { get; } public System.Span NativeValueStorage { get; } - public IntPtr Value { get; set; } + public System.IntPtr Value { get; } public TestCollection ToManaged() => throw null; }"; - - public static string CustomCollectionCustomMarshallerParametersAndModifiers() => CustomCollectionCustomMarshallerParametersAndModifiers(typeof(T).ToString()); } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs index a0b02ce440d5..d47a904d165b 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs @@ -78,10 +78,14 @@ public static IEnumerable CodeSnippetsToCompile() yield return new object[] { CodeSnippets.BasicParametersAndModifiers(), 3, 0 }; yield return new object[] { CodeSnippets.BasicParametersAndModifiers(), 3, 0 }; - // Array with non-integer size param + // Collection with non-integer size param yield return new object[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false), 1, 0 }; yield return new object[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false), 1, 0 }; yield return new object[] { CodeSnippets.MarshalAsArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + yield return new object[] { CodeSnippets.MarshalUsingArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + yield return new object[] { CodeSnippets.MarshalUsingArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + yield return new object[] { CodeSnippets.MarshalUsingArrayParameterWithSizeParam(isByRef: false), 1, 0 }; + // Custom type marshalling with invalid members yield return new object[] { CodeSnippets.CustomStructMarshallingByRefValueProperty, 3, 0 }; @@ -104,6 +108,15 @@ public static IEnumerable CodeSnippetsToCompile() yield return new object[] { CodeSnippets.ImplicitlyBlittableStructParametersAndModifiers("public"), 5, 0 }; yield return new object[] { CodeSnippets.ImplicitlyBlittableGenericTypeParametersAndModifiers(), 5, 0 }; yield return new object[] { CodeSnippets.ImplicitlyBlittableGenericTypeParametersAndModifiers("public"), 5, 0 }; + + // Collection with constant and element size parameter + yield return new object[] { CodeSnippets.MarshalUsingCollectionWithConstantAndElementCount, 2, 0 }; + + // Collection with null element size parameter name + yield return new object[] { CodeSnippets.MarshalUsingCollectionWithNullElementName, 2, 0 }; + + // Generic collection marshaller has different arity than collection. + yield return new object[] { CodeSnippets.GenericCollectionMarshallingArityMismatch, 2, 0 }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs index 4f78d42a75c6..c2f6c66428d9 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs @@ -257,6 +257,7 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; + yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerReturnValueLength() }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs b/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs index 97bc59387711..68b935619852 100644 --- a/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs +++ b/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs @@ -364,9 +364,9 @@ private void AnalyzeNativeMarshalerType(SymbolAnalysisContext context, ITypeSymb continue; } - hasConstructor = hasConstructor || ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type); + hasConstructor = hasConstructor || ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, false); - if (!hasStackallocConstructor && ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, SpanOfByte)) + if (!hasStackallocConstructor && ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, SpanOfByte, false)) { hasStackallocConstructor = true; IFieldSymbol stackAllocSizeField = nativeType.GetMembers("StackBufferSize").OfType().FirstOrDefault(); diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index 31d5f143142a..2b9d466d8bb7 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -13,6 +13,7 @@ static class ManualTypeMarshallingHelper public const string ToManagedMethodName = "ToManaged"; public const string FreeNativeMethodName = "FreeNative"; public const string ManagedValuesPropertyName = "ManagedValues"; + public const string SetUnmarshalledCollectionLengthMethodName = "SetUnmarshalledCollectionLength"; public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol managedType) { @@ -28,14 +29,13 @@ public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol manage public static bool IsManagedToNativeConstructor( IMethodSymbol ctor, ITypeSymbol managedType, - ITypeSymbol int32, bool isCollectionMarshaller) { if (isCollectionMarshaller) { return ctor.Parameters.Length == 2 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) - && SymbolEqualityComparer.Default.Equals(int32, ctor.Parameters[1].Type); + && ctor.Parameters[1].Type.SpecialType == SpecialType.System_Int32; } return ctor.Parameters.Length == 1 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type); @@ -45,7 +45,6 @@ public static bool IsStackallocConstructor( IMethodSymbol ctor, ITypeSymbol managedType, ITypeSymbol spanOfByte, - ITypeSymbol int32, bool isCollectionMarshaller) { if (isCollectionMarshaller) @@ -53,7 +52,7 @@ public static bool IsStackallocConstructor( return ctor.Parameters.Length == 3 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) && SymbolEqualityComparer.Default.Equals(spanOfByte, ctor.Parameters[1].Type) - && SymbolEqualityComparer.Default.Equals(int32, ctor.Parameters[2].Type); + && ctor.Parameters[2].Type.SpecialType == SpecialType.System_Int32; } return ctor.Parameters.Length == 2 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) @@ -83,8 +82,7 @@ public static bool HasFreeNativeMethod(ITypeSymbol type) { return type.GetMembers(FreeNativeMethodName) .OfType() - .Any(m => m is { Parameters: { Length: 0 } } and - ({ ReturnType: { SpecialType: SpecialType.System_Void } })); + .Any(m => m is { IsStatic: false, Parameters: { Length: 0 }, ReturnType: { SpecialType: SpecialType.System_Void } }); } public static IPropertySymbol? FindManagedValuesProperty(ITypeSymbol type) @@ -109,5 +107,16 @@ public static bool TryGetElementTypeFromContiguousCollectionMarshaller(ITypeSymb return true; } + public static bool HasSetUnmarshalledCollectionLengthMethod(ITypeSymbol type) + { + return type.GetMembers(SetUnmarshalledCollectionLengthMethodName) + .OfType() + .Any(m => m is + { + IsStatic: false, + Parameters: { Length: 1 }, + ReturnType: { SpecialType: SpecialType.System_Void } + } && m.Parameters[0].Type.SpecialType == SpecialType.System_Int32); + } } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Resources.Designer.cs b/DllImportGenerator/DllImportGenerator/Resources.Designer.cs index 2acea21707ca..6e9b0ab11cb1 100644 --- a/DllImportGenerator/DllImportGenerator/Resources.Designer.cs +++ b/DllImportGenerator/DllImportGenerator/Resources.Designer.cs @@ -61,7 +61,7 @@ internal Resources() { } /// - /// Looks up a localized string similar to Marshalling an array from unmanaged to managed requires either the 'SizeParamIndex' or 'SizeConst' fields to be set on a 'MarshalAsAttribute'.. + /// Looks up a localized string similar to Marshalling an array from unmanaged to managed requires either the 'SizeParamIndex' or 'SizeConst' fields to be set on a 'MarshalAsAttribute' or the 'ConstantElementCount' or 'CountElementName' properties to be set on a 'MarshalUsingAttribute'.. /// internal static string ArraySizeMustBeSpecified { get { diff --git a/DllImportGenerator/DllImportGenerator/Resources.resx b/DllImportGenerator/DllImportGenerator/Resources.resx index 698f32820bf2..f9b39dd6a923 100644 --- a/DllImportGenerator/DllImportGenerator/Resources.resx +++ b/DllImportGenerator/DllImportGenerator/Resources.resx @@ -118,7 +118,7 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - Marshalling an array from unmanaged to managed requires either the 'SizeParamIndex' or 'SizeConst' fields to be set on a 'MarshalAsAttribute'. + Marshalling an array from unmanaged to managed requires either the 'SizeParamIndex' or 'SizeConst' fields to be set on a 'MarshalAsAttribute' or the 'ConstantElementCount' or 'CountElementName' properties to be set on a 'MarshalUsingAttribute'. The 'SizeParamIndex' value in the 'MarshalAsAttribute' is out of range. diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 61daadfca4d7..f97e1e8a6a56 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -372,7 +372,6 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr } ITypeSymbol spanOfByte = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(compilation.GetSpecialType(SpecialType.System_Byte)); - ITypeSymbol int32 = compilation.GetSpecialType(SpecialType.System_Int32); INamedTypeSymbol nativeType = (INamedTypeSymbol)attrData.ConstructorArguments[0].Value!; @@ -392,7 +391,7 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr } else { - nativeType = nativeType.Construct(namedType.TypeParameters.ToArray()); + nativeType = nativeType.ConstructedFrom.Construct(namedType.TypeArguments.ToArray()); } } else @@ -410,17 +409,17 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr bool hasInt32Constructor = false; foreach (var ctor in nativeType.Constructors) { - if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, int32, isCollectionMarshaller: true) + if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, isCollectionMarshaller: isContiguousCollectionMarshaller) && (valueProperty is null or { GetMethod: not null })) { methods |= SupportedMarshallingMethods.ManagedToNative; } - else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, int32, isCollectionMarshaller: true) + else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, isCollectionMarshaller: isContiguousCollectionMarshaller) && (valueProperty is null or { GetMethod: not null })) { methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc; } - else if (ctor.Parameters.Length == 1 && SymbolEqualityComparer.Default.Equals(ctor.Parameters[0], int32)) + else if (ctor.Parameters.Length == 1 && ctor.Parameters[0].Type.SpecialType == SpecialType.System_Int32) { hasInt32Constructor = true; } @@ -428,7 +427,8 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr // The constructor that takes only the native element size is required for collection marshallers // in the native-to-managed scenario. - if ((!isContiguousCollectionMarshaller || hasInt32Constructor) + if ((!isContiguousCollectionMarshaller + || (hasInt32Constructor && ManualTypeMarshallingHelper.HasSetUnmarshalledCollectionLengthMethod(nativeType))) && ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type) && (valueProperty is null or { SetMethod: not null })) { @@ -437,7 +437,8 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr if (methods == SupportedMarshallingMethods.None) { - // TODO: Diagnostic since no marshalling methods are supported. + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; } if (isContiguousCollectionMarshaller) From 25fc6f7ae0d341c990ae5ecd9a1e9d0c80561a33 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 16:53:51 -0700 Subject: [PATCH 13/30] Add test for elementindirectionlevel processing. --- .../CodeSnippets.cs | 32 +++++++++++++++++-- .../DllImportGenerator.UnitTests/Compiles.cs | 1 + .../ManualTypeMarshallingHelper.cs | 12 ++++++- ...onBlittableElementsMarshallingGenerator.cs | 10 +++--- ...onBlittableElementsMarshallingGenerator.cs | 8 ++--- .../DllImportGenerator/TypePositionInfo.cs | 6 ++++ 6 files changed, 56 insertions(+), 13 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs index 8b4a5cb488ad..f48d95f5ad96 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs @@ -1187,7 +1187,7 @@ public static partial void Method( public static string MarshalUsingArrayParameterWithSizeParam(bool isByRef) => MarshalUsingArrayParameterWithSizeParam(typeof(T).ToString(), isByRef); - public static string MarshalUsingCollectionWithConstantAndElementCount = $@" + public static string MarshalUsingCollectionWithConstantAndElementCount => $@" using System.Runtime.InteropServices; partial class Test {{ @@ -1198,7 +1198,7 @@ public static partial void Method( ); }}"; - public static string MarshalUsingCollectionWithNullElementName = $@" + public static string MarshalUsingCollectionWithNullElementName => $@" using System.Runtime.InteropServices; partial class Test {{ @@ -1209,7 +1209,7 @@ public static partial void Method( ); }}"; - public static string GenericCollectionMarshallingArityMismatch = BasicParameterByValue("TestCollection") + @" + public static string GenericCollectionMarshallingArityMismatch => BasicParameterByValue("TestCollection") + @" [NativeMarshalling(typeof(Marshaller<,>))] class TestCollection {} @@ -1222,5 +1222,31 @@ public Marshaller(TestCollection managed, int nativeElementSize) : this() {} public System.IntPtr Value { get; } public TestCollection ToManaged() => throw null; }"; + + public static string GenericCollectionWithCustomElementMarshalling => @" + +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + [return:MarshalUsing(ConstantElementCount=10)] + [return:MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] + public static partial TestCollection Method( + [MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] TestCollection p, + [MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] in TestCollection pIn, + int pRefSize, + [MarshalUsing(CountElementName = ""pRefSize""), MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] ref TestCollection pRef, + [MarshalUsing(CountElementName = ""pOutSize"")][MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] out TestCollection pOut, + out int pOutSize + ); +} + +struct IntWrapper +{ + public IntWrapper(int i){} + public int ToManaged() => throw null; +} + +" + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs index c2f6c66428d9..5f5938575561 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs @@ -258,6 +258,7 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers() }; yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerReturnValueLength() }; + yield return new[] { CodeSnippets.GenericCollectionWithCustomElementMarshalling }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index 2b9d466d8bb7..1dc51ded21f5 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -13,6 +13,7 @@ static class ManualTypeMarshallingHelper public const string ToManagedMethodName = "ToManaged"; public const string FreeNativeMethodName = "FreeNative"; public const string ManagedValuesPropertyName = "ManagedValues"; + public const string NativeValueStoragePropertyName = "NativeValueStorage"; public const string SetUnmarshalledCollectionLengthMethodName = "SetUnmarshalledCollectionLength"; public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol managedType) @@ -90,7 +91,7 @@ public static bool HasFreeNativeMethod(ITypeSymbol type) return type .GetMembers(ManagedValuesPropertyName) .OfType() - .FirstOrDefault(p => !p.IsStatic); + .FirstOrDefault(p => p is { IsStatic: false, GetMethod: not null, ReturnsByRef: false, ReturnsByRefReadonly: false }); } public static bool TryGetElementTypeFromContiguousCollectionMarshaller(ITypeSymbol type, out ITypeSymbol elementType) @@ -118,5 +119,14 @@ public static bool HasSetUnmarshalledCollectionLengthMethod(ITypeSymbol type) ReturnType: { SpecialType: SpecialType.System_Void } } && m.Parameters[0].Type.SpecialType == SpecialType.System_Int32); } + + public static bool HasNativeValueStorageProperty(ITypeSymbol type, ITypeSymbol spanOfByte) + { + return type + .GetMembers(NativeValueStoragePropertyName) + .OfType() + .Any(p => p is {IsStatic: false, GetMethod: not null, ReturnsByRef: false, ReturnsByRefReadonly: false } + && SymbolEqualityComparer.Default.Equals(p.Type, spanOfByte)); + } } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs index d6be98aee426..9905c13498df 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs @@ -43,7 +43,7 @@ public override IEnumerable GenerateIntermediateMarshallingStat MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName(marshalerIdentifier), - IdentifierName("ManagedValues")), + IdentifierName(ManualTypeMarshallingHelper.ManagedValuesPropertyName)), IdentifierName("CopyTo"))) .AddArgumentListArguments( Argument( @@ -66,7 +66,7 @@ public override IEnumerable GenerateIntermediateMarshallingStat MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName(marshalerIdentifier), - IdentifierName("NativeValueStorage"))))))); + IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName))))))); } public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) @@ -84,7 +84,7 @@ public override IEnumerable GeneratePreUnmarshallingStatements( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName(marshalerIdentifier), - IdentifierName("SetUnmarshalledCollectionLength"))) + IdentifierName(ManualTypeMarshallingHelper.SetUnmarshalledCollectionLengthMethodName))) .AddArgumentListArguments(Argument(numElementsExpression))); } @@ -115,14 +115,14 @@ public override IEnumerable GenerateIntermediateUnmarshallingSt MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName(marshalerIdentifier), - IdentifierName("NativeValueStorage")))), + IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName)))), IdentifierName("CopyTo"))) .AddArgumentListArguments( Argument( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName(marshalerIdentifier), - IdentifierName("ManagedValues"))))); + IdentifierName(ManualTypeMarshallingHelper.ManagedValuesPropertyName))))); } } } \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs index 8165109a6595..d8e44cdecd8e 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs @@ -71,7 +71,7 @@ private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(TypePositi .AddArgumentListArguments( Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(GetMarshallerIdentifier(info, context)), - IdentifierName("NativeValueStorage"))))))))); + IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName))))))))); } private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) @@ -85,7 +85,7 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in context); string collectionIdentifierForLength = useManagedSpanForLength - ? $"{marshalerIdentifier}.ManagedValues" + ? $"{marshalerIdentifier}.{ManualTypeMarshallingHelper.ManagedValuesPropertyName}" : nativeSpanIdentifier; TypePositionInfo localElementInfo = elementInfo with @@ -133,7 +133,7 @@ public override IEnumerable GeneratePreUnmarshallingStatements( MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, IdentifierName(marshalerIdentifier), - IdentifierName("SetUnmarshalledCollectionLength"))) + IdentifierName(ManualTypeMarshallingHelper.SetUnmarshalledCollectionLengthMethodName))) .AddArgumentListArguments(Argument(numElementsExpression))); } @@ -166,7 +166,7 @@ public override SyntaxNode VisitAssignmentExpression(AssignmentExpressionSyntax if (node.Left.ToString() == nativeIdentifier) { return node.WithRight( - CastExpression(ParseTypeName("System.IntPtr"), node.Right)); + CastExpression(MarshallerHelpers.SystemIntPtrType, node.Right)); } return node; diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index f97e1e8a6a56..6634e632fe3e 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -443,6 +443,12 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr if (isContiguousCollectionMarshaller) { + if (!ManualTypeMarshallingHelper.HasNativeValueStorageProperty(nativeType, spanOfByte)) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + if (!ManualTypeMarshallingHelper.TryGetElementTypeFromContiguousCollectionMarshaller(nativeType, out ITypeSymbol elementType)) { diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); From d2708519e0b2818106d0a8e1f519663da2497839 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 17:39:25 -0700 Subject: [PATCH 14/30] Refactor marshalling attribute parsing into its own type to enable better holistic analysis of all use-site marshalling attributes. --- .../CodeSnippets.cs | 46 +- .../CompileFails.cs | 4 + .../DllImportGenerator/DllImportStub.cs | 6 +- .../MarshallingAttributeInfo.cs | 518 ++++++++++++++++++ .../DllImportGenerator/TypePositionInfo.cs | 495 +---------------- 5 files changed, 575 insertions(+), 494 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs index f48d95f5ad96..d8b24448b00e 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs @@ -1224,7 +1224,6 @@ public Marshaller(TestCollection managed, int nativeElementSize) : this() {} }"; public static string GenericCollectionWithCustomElementMarshalling => @" - using System.Runtime.InteropServices; partial class Test { @@ -1248,5 +1247,50 @@ public IntWrapper(int i){} } " + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); + + public static string GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionLevel => @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + [MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] [MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 1)] TestCollection p); +} + +struct IntWrapper +{ + public IntWrapper(int i){} + public int ToManaged() => throw null; +} + +" + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); + + public static string GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionLevel => @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + [MarshalUsing(typeof(IntWrapper), ElementIndirectionLevel = 2)] TestCollection p); +} + +struct IntWrapper +{ + public IntWrapper(int i){} + public int ToManaged() => throw null; +} + +" + CustomCollectionWithMarshaller(enableDefaultMarshalling: true); + + public static string MarshalAsAndMarshalUsingOnReturnValue => @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + [return:MarshalUsing(ConstantElementCount=10)] + [return:MarshalAs(UnmanagedType.LPArray, SizeConst=10)] + public static partial int[] Method(); +} +"; } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs index d47a904d165b..c6fa9e66a3a9 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs @@ -117,6 +117,10 @@ public static IEnumerable CodeSnippetsToCompile() // Generic collection marshaller has different arity than collection. yield return new object[] { CodeSnippets.GenericCollectionMarshallingArityMismatch, 2, 0 }; + + yield return new object[] { CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 2, 0 }; + yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionLevel, 2, 0 }; + yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionLevel, 2, 0 }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator/DllImportStub.cs b/DllImportGenerator/DllImportGenerator/DllImportStub.cs index 52bff59b84dd..eae8ad7d79f7 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportStub.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportStub.cs @@ -157,12 +157,14 @@ public static DllImportStub Create( var defaultInfo = new DefaultMarshallingInfo(defaultEncoding); + var marshallingAttributeParser = new MarshallingAttributeInfoParser(env.Compilation, diagnostics, defaultInfo, method); + // Determine parameter and return types var paramsTypeInfo = new List(); for (int i = 0; i < method.Parameters.Length; i++) { var param = method.Parameters[i]; - var typeInfo = TypePositionInfo.CreateForParameter(param, defaultInfo, env.Compilation, diagnostics, method); + var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingAttributeParser, env.Compilation); typeInfo = typeInfo with { ManagedIndex = i, @@ -171,7 +173,7 @@ public static DllImportStub Create( paramsTypeInfo.Add(typeInfo); } - TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, method.GetReturnTypeAttributes(), defaultInfo, env.Compilation, diagnostics, method); + TypePositionInfo retTypeInfo = TypePositionInfo.CreateForType(method.ReturnType, marshallingAttributeParser.ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())); retTypeInfo = retTypeInfo with { ManagedIndex = TypePositionInfo.ReturnIndex, diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index c782ad2ed60e..5f1367e367e6 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Runtime.InteropServices; namespace Microsoft.Interop @@ -130,4 +131,521 @@ internal sealed record NativeContiguousCollectionMarshallingInfo( NativeTypePinnable, UseDefaultMarshalling ); + + class MarshallingAttributeInfoParser + { + private readonly Compilation compilation; + private readonly GeneratorDiagnostics diagnostics; + private readonly DefaultMarshallingInfo defaultInfo; + private readonly ISymbol contextSymbol; + private readonly ITypeSymbol marshalAsAttribute; + private readonly ITypeSymbol marshalUsingAttribute; + + public MarshallingAttributeInfoParser( + Compilation compilation, + GeneratorDiagnostics diagnostics, + DefaultMarshallingInfo defaultInfo, + ISymbol contextSymbol) + { + this.compilation = compilation; + this.diagnostics = diagnostics; + this.defaultInfo = defaultInfo; + this.contextSymbol = contextSymbol; + marshalAsAttribute = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)!; + marshalUsingAttribute = compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute)!; + } + + internal MarshallingInfo ParseMarshallingInfo( + ITypeSymbol managedType, + IEnumerable useSiteAttributes) + { + Dictionary marshallingAttributesByIndirectionLevel = new(); + foreach (AttributeData attribute in useSiteAttributes) + { + if (TryGetAttributeIndirectionLevel(attribute, out int indirectionLevel)) + { + if (marshallingAttributesByIndirectionLevel.ContainsKey(indirectionLevel)) + { + diagnostics.ReportConfigurationNotSupported(attribute, "Marshalling Data for Indirection Level", indirectionLevel.ToString()); + return NoMarshallingInfo.Instance; + } + marshallingAttributesByIndirectionLevel.Add(indirectionLevel, attribute); + } + } + + return GetMarshallingInfo(managedType, marshallingAttributesByIndirectionLevel); + } + + private MarshallingInfo GetMarshallingInfo(ITypeSymbol type, Dictionary useSiteAttributes, int indirectionLevel = 0) + { + CountInfo parsedCountInfo = NoCountInfo.Instance; + + if (useSiteAttributes.TryGetValue(indirectionLevel, out AttributeData useSiteAttribute)) + { + INamedTypeSymbol attributeClass = useSiteAttribute.AttributeClass!; + + if (indirectionLevel == 0 + && SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) + { + // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute + return CreateInfoFromMarshalAs(type, useSiteAttribute); + } + else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass)) + { + if (parsedCountInfo != NoCountInfo.Instance) + { + diagnostics.ReportConfigurationNotSupported(useSiteAttribute, "Duplicate Count Info"); + return NoMarshallingInfo.Instance; + } + parsedCountInfo = CreateCountInfo(useSiteAttribute); + if (useSiteAttribute.ConstructorArguments.Length != 0) + { + return CreateNativeMarshallingInfo(type, useSiteAttribute, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo, useSiteAttributes); + } + } + } + + // If we aren't overriding the marshalling at usage time, + // then fall back to the information on the element type itself. + foreach (var typeAttribute in type.GetAttributes()) + { + INamedTypeSymbol attributeClass = typeAttribute.AttributeClass!; + + if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.BlittableTypeAttribute), attributeClass)) + { + // If type is generic, then we need to re-evaluate that it is blittable at usage time. + if (type is INamedTypeSymbol { IsGenericType: false } || type.HasOnlyBlittableFields()) + { + return new BlittableTypeAttributeInfo(); + } + break; + } + else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.NativeMarshallingAttribute), attributeClass)) + { + return CreateNativeMarshallingInfo(type, typeAttribute, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo, useSiteAttributes); + } + else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.GeneratedMarshallingAttribute), attributeClass)) + { + return type.IsConsideredBlittable() ? new BlittableTypeAttributeInfo() : new GeneratedNativeMarshallingAttributeInfo(null! /* TODO: determine naming convention */); + } + } + + // If the type doesn't have custom attributes that dictate marshalling, + // then consider the type itself. + if (TryCreateTypeBasedMarshallingInfo( + type, + parsedCountInfo, + indirectionLevel, + useSiteAttributes, + out MarshallingInfo infoMaybe)) + { + return infoMaybe; + } + + // No marshalling info was computed, but a character encoding was provided. + // If the type is a character or string then pass on these details. + if (defaultInfo.CharEncoding != CharEncoding.Undefined + && (type.SpecialType == SpecialType.System_Char + || type.SpecialType == SpecialType.System_String)) + { + return new MarshallingInfoStringSupport(defaultInfo.CharEncoding); + } + + return NoMarshallingInfo.Instance; + } + + CountInfo CreateCountInfo(AttributeData marshalUsingData) + { + int? constSize = null; + string? elementName = null; + foreach (var arg in marshalUsingData.NamedArguments) + { + if (arg.Key == "ConstantElementCount") + { + constSize = (int)arg.Value.Value!; + } + else if (arg.Key == "CountElementName") + { + if (arg.Value.Value is null) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", "null"); + return NoCountInfo.Instance; + } + elementName = (string)arg.Value.Value!; + } + } + + if (constSize is not null && elementName is not null) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "ConstantElementCount and CountElementName combined"); + } + else if (constSize is not null) + { + return new ConstSizeCountInfo(constSize.Value); + } + else if (elementName is not null) + { + TypePositionInfo? elementInfo = CreateForElementName(elementName); + if (elementInfo is null) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", elementName); + return NoCountInfo.Instance; + } + return new CountElementCountInfo(elementInfo); + } + + return NoCountInfo.Instance; + } + + private TypePositionInfo? CreateForElementName(string elementName) + { + if (contextSymbol is IMethodSymbol method) + { + if (elementName == CountElementCountInfo.ReturnValueElementName) + { + return TypePositionInfo.CreateForType( + method.ReturnType, + ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())) with + { + ManagedIndex = TypePositionInfo.ReturnIndex + }; + } + + foreach (var param in method.Parameters) + { + if (param.Name == elementName) + { + return TypePositionInfo.CreateForParameter(param, this, compilation); + } + } + } + else if (contextSymbol is INamedTypeSymbol _) + { + // TODO: Handle when we create a struct marshalling generator + // Do we want to support CountElementName pointing to only fields, or properties as well? + // If only fields, how do we handle properties with generated backing fields? + } + + return null; + } + + MarshallingInfo CreateInfoFromMarshalAs( + ITypeSymbol type, + AttributeData attrData) + { + object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; + UnmanagedType unmanagedType = unmanagedTypeObj is short + ? (UnmanagedType)(short)unmanagedTypeObj + : (UnmanagedType)unmanagedTypeObj; + if (!Enum.IsDefined(typeof(UnmanagedType), unmanagedType) + || unmanagedType == UnmanagedType.CustomMarshaler + || unmanagedType == UnmanagedType.SafeArray) + { + diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); + } + bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; + UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedData; + SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; + + // All other data on attribute is defined as NamedArguments. + foreach (var namedArg in attrData.NamedArguments) + { + switch (namedArg.Key) + { + default: + Debug.Fail($"An unknown member was found on {nameof(MarshalAsAttribute)}"); + continue; + case nameof(MarshalAsAttribute.SafeArraySubType): + case nameof(MarshalAsAttribute.SafeArrayUserDefinedSubType): + case nameof(MarshalAsAttribute.IidParameterIndex): + case nameof(MarshalAsAttribute.MarshalTypeRef): + case nameof(MarshalAsAttribute.MarshalType): + case nameof(MarshalAsAttribute.MarshalCookie): + diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + break; + case nameof(MarshalAsAttribute.ArraySubType): + if (!isArrayType) + { + diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + } + elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; + break; + case nameof(MarshalAsAttribute.SizeConst): + if (!isArrayType) + { + diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + } + arraySizeInfo = arraySizeInfo with { ConstSize = (int)namedArg.Value.Value! }; + break; + case nameof(MarshalAsAttribute.SizeParamIndex): + if (!isArrayType) + { + diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); + } + arraySizeInfo = arraySizeInfo with { ParamIndex = (short)namedArg.Value.Value! }; + break; + } + } + + if (!isArrayType) + { + return new MarshalAsInfo(unmanagedType, defaultInfo.CharEncoding); + } + + if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + { + diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); + return NoMarshallingInfo.Instance; + } + + MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; + if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedData) + { + elementMarshallingInfo = new MarshalAsInfo(elementUnmanagedType, defaultInfo.CharEncoding); + } + else + { + // Indirection level does not matter since we don't pass down attributes to be inspected. + elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 0); + } + + INamedTypeSymbol? arrayMarshaller; + + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); + } + else + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + } + + if (arrayMarshaller is null) + { + // If the array marshaler type is not available, then we cannot marshal arrays. + return NoMarshallingInfo.Instance; + } + + return new NativeContiguousCollectionMarshallingInfo( + NativeMarshallingType: arrayMarshaller, + ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, + MarshallingMethods: ~SupportedMarshallingMethods.Pinning, + NativeTypePinnable: true, + UseDefaultMarshalling: true, + ElementCountInfo: arraySizeInfo, + ElementType: elementType, + ElementMarshallingInfo: elementMarshallingInfo); + } + + MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo, Dictionary useSiteAttributes) + { + SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; + + if (!isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) + { + methods |= SupportedMarshallingMethods.Pinning; + } + + ITypeSymbol spanOfByte = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(compilation.GetSpecialType(SpecialType.System_Byte)); + + INamedTypeSymbol nativeType = (INamedTypeSymbol)attrData.ConstructorArguments[0].Value!; + + if (nativeType.IsUnboundGenericType) + { + if (isMarshalUsingAttribute) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + else if (type is INamedTypeSymbol namedType) + { + if (namedType.Arity != nativeType.Arity) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + else + { + nativeType = nativeType.ConstructedFrom.Construct(namedType.TypeArguments.ToArray()); + } + } + else + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + } + + ITypeSymbol contiguousCollectionMarshalerAttribute = compilation.GetTypeByMetadataName(TypeNames.GenericContiguousCollectionMarshallerAttribute)!; + + bool isContiguousCollectionMarshaller = nativeType.GetAttributes().Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, contiguousCollectionMarshalerAttribute)); + IPropertySymbol? valueProperty = ManualTypeMarshallingHelper.FindValueProperty(nativeType); + + bool hasInt32Constructor = false; + foreach (var ctor in nativeType.Constructors) + { + if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, isCollectionMarshaller: isContiguousCollectionMarshaller) + && (valueProperty is null or { GetMethod: not null })) + { + methods |= SupportedMarshallingMethods.ManagedToNative; + } + else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, isCollectionMarshaller: isContiguousCollectionMarshaller) + && (valueProperty is null or { GetMethod: not null })) + { + methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc; + } + else if (ctor.Parameters.Length == 1 && ctor.Parameters[0].Type.SpecialType == SpecialType.System_Int32) + { + hasInt32Constructor = true; + } + } + + // The constructor that takes only the native element size is required for collection marshallers + // in the native-to-managed scenario. + if ((!isContiguousCollectionMarshaller + || (hasInt32Constructor && ManualTypeMarshallingHelper.HasSetUnmarshalledCollectionLengthMethod(nativeType))) + && ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type) + && (valueProperty is null or { SetMethod: not null })) + { + methods |= SupportedMarshallingMethods.NativeToManaged; + } + + if (methods == SupportedMarshallingMethods.None) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + if (isContiguousCollectionMarshaller) + { + if (!ManualTypeMarshallingHelper.HasNativeValueStorageProperty(nativeType, spanOfByte)) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + if (!ManualTypeMarshallingHelper.TryGetElementTypeFromContiguousCollectionMarshaller(nativeType, out ITypeSymbol elementType)) + { + diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); + return NoMarshallingInfo.Instance; + } + + return new NativeContiguousCollectionMarshallingInfo( + nativeType, + valueProperty?.Type, + methods, + NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, + UseDefaultMarshalling: !isMarshalUsingAttribute, + parsedCountInfo, + elementType, + GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1)); + } + + return new NativeMarshallingAttributeInfo( + nativeType, + valueProperty?.Type, + methods, + NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, + UseDefaultMarshalling: !isMarshalUsingAttribute); + } + + bool TryCreateTypeBasedMarshallingInfo( + ITypeSymbol type, + CountInfo parsedCountInfo, + int indirectionLevel, + Dictionary useSiteAttributes, + out MarshallingInfo marshallingInfo) + { + var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); + if (conversion.Exists + && conversion.IsImplicit + && (conversion.IsReference || conversion.IsIdentity)) + { + bool hasAccessibleDefaultConstructor = false; + if (type is INamedTypeSymbol named && !named.IsAbstract && named.InstanceConstructors.Length > 0) + { + foreach (var ctor in named.InstanceConstructors) + { + if (ctor.Parameters.Length == 0) + { + hasAccessibleDefaultConstructor = compilation.IsSymbolAccessibleWithin(ctor, contextSymbol.ContainingType); + break; + } + } + } + marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor); + return true; + } + + if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) + { + INamedTypeSymbol? arrayMarshaller; + + if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); + } + else + { + arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); + } + + if (arrayMarshaller is null) + { + // If the array marshaler type is not available, then we cannot marshal arrays. + marshallingInfo = NoMarshallingInfo.Instance; + return false; + } + + marshallingInfo = new NativeContiguousCollectionMarshallingInfo( + NativeMarshallingType: arrayMarshaller, + ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, + MarshallingMethods: ~SupportedMarshallingMethods.Pinning, + NativeTypePinnable: true, + UseDefaultMarshalling: true, + ElementCountInfo: parsedCountInfo, + ElementType: elementType, + ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1)); + return true; + } + + if (type is INamedTypeSymbol { IsValueType: true } valueType + && !valueType.IsExposedOutsideOfCurrentCompilation() + && valueType.IsConsideredBlittable()) + { + // Allow implicit [BlittableType] on internal value types. + marshallingInfo = new BlittableTypeAttributeInfo(); + return true; + } + + marshallingInfo = NoMarshallingInfo.Instance; + return false; + } + + private bool TryGetAttributeIndirectionLevel(AttributeData attrData, out int indirectionLevel) + { + if (SymbolEqualityComparer.Default.Equals(attrData.AttributeClass, marshalAsAttribute)) + { + indirectionLevel = 0; + return true; + } + + if (!SymbolEqualityComparer.Default.Equals(attrData.AttributeClass, marshalUsingAttribute)) + { + indirectionLevel = 0; + return false; + } + + foreach (var arg in attrData.NamedArguments) + { + if (arg.Key == "ElementIndirectionLevel") + { + indirectionLevel = (int)arg.Value.Value!; + return true; + } + } + indirectionLevel = 0; + return true; + } + } } diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 6634e632fe3e..7c494108320a 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -83,9 +83,9 @@ private TypePositionInfo() public MarshallingInfo MarshallingAttributeInfo { get; init; } - public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, IMethodSymbol methodSymbol) + public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingAttributeInfoParser attributeParser, Compilation compilation) { - var marshallingInfo = GetMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes(), defaultInfo, compilation, diagnostics, methodSymbol); + var marshallingInfo = attributeParser.ParseMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes()); var typeInfo = new TypePositionInfo() { ManagedType = paramSymbol.Type, @@ -99,13 +99,12 @@ public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, return typeInfo; } - public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, ISymbol symbol) + public static TypePositionInfo CreateForType(ITypeSymbol type, MarshallingInfo marshallingInfo, string identifier = "") { - var marshallingInfo = GetMarshallingInfo(type, attributes, defaultInfo, compilation, diagnostics, symbol); var typeInfo = new TypePositionInfo() { ManagedType = type, - InstanceIdentifier = string.Empty, + InstanceIdentifier = identifier, RefKind = RefKind.None, RefKindSyntax = SyntaxKind.None, MarshallingAttributeInfo = marshallingInfo @@ -114,492 +113,6 @@ public static TypePositionInfo CreateForType(ITypeSymbol type, IEnumerable attributes, DefaultMarshallingInfo defaultInfo, Compilation compilation, GeneratorDiagnostics diagnostics, ISymbol contextSymbol, int indirectionLevel = 0) - { - CountInfo parsedCountInfo = NoCountInfo.Instance; - // Look at attributes passed in - usage specific. - foreach (var attrData in attributes) - { - INamedTypeSymbol attributeClass = attrData.AttributeClass!; - - if (indirectionLevel == 0 - && SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute), attributeClass)) - { - // https://docs.microsoft.com/dotnet/api/system.runtime.interopservices.marshalasattribute - return CreateMarshalAsInfo(type, attrData, defaultInfo, compilation); - } - else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute), attributeClass) - && AttributeAppliesToCurrentIndirectionLevel(attrData, indirectionLevel)) - { - if (parsedCountInfo != NoCountInfo.Instance) - { - diagnostics.ReportConfigurationNotSupported(attrData, "Duplicate Count Info"); - return NoMarshallingInfo.Instance; - } - parsedCountInfo = CreateCountInfo(attrData); - if (attrData.ConstructorArguments.Length != 0) - { - return CreateNativeMarshallingInfo(type, attrData, isMarshalUsingAttribute: true, indirectionLevel, parsedCountInfo); - } - } - } - - // If we aren't overriding the marshalling at usage time, - // then fall back to the information on the element type itself. - foreach (var attrData in type.GetAttributes()) - { - INamedTypeSymbol attributeClass = attrData.AttributeClass!; - - if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.BlittableTypeAttribute), attributeClass)) - { - // If type is generic, then we need to re-evaluate that it is blittable at usage time. - if (type is INamedTypeSymbol { IsGenericType: false } || type.HasOnlyBlittableFields()) - { - return new BlittableTypeAttributeInfo(); - } - break; - } - else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.NativeMarshallingAttribute), attributeClass)) - { - return CreateNativeMarshallingInfo(type, attrData, isMarshalUsingAttribute: false, indirectionLevel, parsedCountInfo); - } - else if (SymbolEqualityComparer.Default.Equals(compilation.GetTypeByMetadataName(TypeNames.GeneratedMarshallingAttribute), attributeClass)) - { - return type.IsConsideredBlittable() ? new BlittableTypeAttributeInfo() : new GeneratedNativeMarshallingAttributeInfo(null! /* TODO: determine naming convention */); - } - } - - // If the type doesn't have custom attributes that dictate marshalling, - // then consider the type itself. - if (TryCreateTypeBasedMarshallingInfo( - type, - defaultInfo, - compilation, - diagnostics, - contextSymbol.ContainingType, - parsedCountInfo, - indirectionLevel, - out MarshallingInfo infoMaybe)) - { - return infoMaybe; - } - - // No marshalling info was computed, but a character encoding was provided. - // If the type is a character or string then pass on these details. - if (defaultInfo.CharEncoding != CharEncoding.Undefined - && (type.SpecialType == SpecialType.System_Char - || type.SpecialType == SpecialType.System_String)) - { - return new MarshallingInfoStringSupport(defaultInfo.CharEncoding); - } - - return NoMarshallingInfo.Instance; - - MarshallingInfo CreateMarshalAsInfo( - ITypeSymbol type, - AttributeData attrData, - DefaultMarshallingInfo defaultInfo, - Compilation compilation) - { - object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; - UnmanagedType unmanagedType = unmanagedTypeObj is short - ? (UnmanagedType)(short)unmanagedTypeObj - : (UnmanagedType)unmanagedTypeObj; - if (!Enum.IsDefined(typeof(UnmanagedType), unmanagedType) - || unmanagedType == UnmanagedType.CustomMarshaler - || unmanagedType == UnmanagedType.SafeArray) - { - diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - } - bool isArrayType = unmanagedType == UnmanagedType.LPArray || unmanagedType == UnmanagedType.ByValArray; - UnmanagedType elementUnmanagedType = (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedData; - SizeAndParamIndexInfo arraySizeInfo = SizeAndParamIndexInfo.Unspecified; - - // All other data on attribute is defined as NamedArguments. - foreach (var namedArg in attrData.NamedArguments) - { - switch (namedArg.Key) - { - default: - Debug.Fail($"An unknown member was found on {nameof(MarshalAsAttribute)}"); - continue; - case nameof(MarshalAsAttribute.SafeArraySubType): - case nameof(MarshalAsAttribute.SafeArrayUserDefinedSubType): - case nameof(MarshalAsAttribute.IidParameterIndex): - case nameof(MarshalAsAttribute.MarshalTypeRef): - case nameof(MarshalAsAttribute.MarshalType): - case nameof(MarshalAsAttribute.MarshalCookie): - diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - break; - case nameof(MarshalAsAttribute.ArraySubType): - if (!isArrayType) - { - diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - elementUnmanagedType = (UnmanagedType)namedArg.Value.Value!; - break; - case nameof(MarshalAsAttribute.SizeConst): - if (!isArrayType) - { - diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - arraySizeInfo = arraySizeInfo with { ConstSize = (int)namedArg.Value.Value! }; - break; - case nameof(MarshalAsAttribute.SizeParamIndex): - if (!isArrayType) - { - diagnostics.ReportConfigurationNotSupported(attrData, $"{attrData.AttributeClass!.Name}{Type.Delimiter}{namedArg.Key}"); - } - arraySizeInfo = arraySizeInfo with { ParamIndex = (short)namedArg.Value.Value! }; - break; - } - } - - if (!isArrayType) - { - return new MarshalAsInfo(unmanagedType, defaultInfo.CharEncoding); - } - - if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - diagnostics.ReportConfigurationNotSupported(attrData, nameof(UnmanagedType), unmanagedType.ToString()); - return NoMarshallingInfo.Instance; - } - - MarshallingInfo elementMarshallingInfo = NoMarshallingInfo.Instance; - if (elementUnmanagedType != (UnmanagedType)SizeAndParamIndexInfo.UnspecifiedData) - { - elementMarshallingInfo = new MarshalAsInfo(elementUnmanagedType, defaultInfo.CharEncoding); - } - else - { - // Indirection level does not matter since we don't pass down attributes to be inspected. - elementMarshallingInfo = GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, contextSymbol, 0); - } - - INamedTypeSymbol? arrayMarshaller; - - if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) - { - arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); - } - else - { - arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); - } - - if (arrayMarshaller is null) - { - // If the array marshaler type is not available, then we cannot marshal arrays. - return NoMarshallingInfo.Instance; - } - - return new NativeContiguousCollectionMarshallingInfo( - NativeMarshallingType: arrayMarshaller, - ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, - MarshallingMethods: ~SupportedMarshallingMethods.Pinning, - NativeTypePinnable : true, - UseDefaultMarshalling: true, - ElementCountInfo: arraySizeInfo, - ElementType: elementType, - ElementMarshallingInfo: elementMarshallingInfo); - } - - CountInfo CreateCountInfo(AttributeData marshalUsingData) - { - int? constSize = null; - string? elementName = null; - foreach (var arg in marshalUsingData.NamedArguments) - { - if (arg.Key == "ConstantElementCount") - { - constSize = (int)arg.Value.Value!; - } - else if (arg.Key == "CountElementName") - { - if (arg.Value.Value is null) - { - diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", "null"); - return NoCountInfo.Instance; - } - elementName = (string)arg.Value.Value!; - } - } - - if (constSize is not null && elementName is not null) - { - diagnostics.ReportConfigurationNotSupported(marshalUsingData, "ConstantElementCount and CountElementName combined"); - } - else if (constSize is not null) - { - return new ConstSizeCountInfo(constSize.Value); - } - else if (elementName is not null) - { - TypePositionInfo? elementInfo = CreateForElementName(compilation, diagnostics, defaultInfo, contextSymbol, elementName); - if (elementInfo is null) - { - diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", elementName); - return NoCountInfo.Instance; - } - return new CountElementCountInfo(elementInfo); - } - - return NoCountInfo.Instance; - } - - MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo) - { - SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; - - if (!isMarshalUsingAttribute && ManualTypeMarshallingHelper.FindGetPinnableReference(type) is not null) - { - methods |= SupportedMarshallingMethods.Pinning; - } - - ITypeSymbol spanOfByte = compilation.GetTypeByMetadataName(TypeNames.System_Span_Metadata)!.Construct(compilation.GetSpecialType(SpecialType.System_Byte)); - - INamedTypeSymbol nativeType = (INamedTypeSymbol)attrData.ConstructorArguments[0].Value!; - - if (nativeType.IsUnboundGenericType) - { - if (isMarshalUsingAttribute) - { - diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - else if (type is INamedTypeSymbol namedType) - { - if (namedType.Arity != nativeType.Arity) - { - diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - else - { - nativeType = nativeType.ConstructedFrom.Construct(namedType.TypeArguments.ToArray()); - } - } - else - { - diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - } - - ITypeSymbol contiguousCollectionMarshalerAttribute = compilation.GetTypeByMetadataName(TypeNames.GenericContiguousCollectionMarshallerAttribute)!; - - bool isContiguousCollectionMarshaller = nativeType.GetAttributes().Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, contiguousCollectionMarshalerAttribute)); - IPropertySymbol? valueProperty = ManualTypeMarshallingHelper.FindValueProperty(nativeType); - - bool hasInt32Constructor = false; - foreach (var ctor in nativeType.Constructors) - { - if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, isCollectionMarshaller: isContiguousCollectionMarshaller) - && (valueProperty is null or { GetMethod: not null })) - { - methods |= SupportedMarshallingMethods.ManagedToNative; - } - else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, isCollectionMarshaller: isContiguousCollectionMarshaller) - && (valueProperty is null or { GetMethod: not null })) - { - methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc; - } - else if (ctor.Parameters.Length == 1 && ctor.Parameters[0].Type.SpecialType == SpecialType.System_Int32) - { - hasInt32Constructor = true; - } - } - - // The constructor that takes only the native element size is required for collection marshallers - // in the native-to-managed scenario. - if ((!isContiguousCollectionMarshaller - || (hasInt32Constructor && ManualTypeMarshallingHelper.HasSetUnmarshalledCollectionLengthMethod(nativeType))) - && ManualTypeMarshallingHelper.HasToManagedMethod(nativeType, type) - && (valueProperty is null or { SetMethod: not null })) - { - methods |= SupportedMarshallingMethods.NativeToManaged; - } - - if (methods == SupportedMarshallingMethods.None) - { - diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - if (isContiguousCollectionMarshaller) - { - if (!ManualTypeMarshallingHelper.HasNativeValueStorageProperty(nativeType, spanOfByte)) - { - diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - if (!ManualTypeMarshallingHelper.TryGetElementTypeFromContiguousCollectionMarshaller(nativeType, out ITypeSymbol elementType)) - { - diagnostics.ReportConfigurationNotSupported(attrData, "Native Type", nativeType.ToDisplayString()); - return NoMarshallingInfo.Instance; - } - - return new NativeContiguousCollectionMarshallingInfo( - nativeType, - valueProperty?.Type, - methods, - NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, - UseDefaultMarshalling: !isMarshalUsingAttribute, - parsedCountInfo, - elementType, - GetMarshallingInfo(elementType, attributes, defaultInfo, compilation, diagnostics, contextSymbol, indirectionLevel + 1)); - } - - return new NativeMarshallingAttributeInfo( - nativeType, - valueProperty?.Type, - methods, - NativeTypePinnable: ManualTypeMarshallingHelper.FindGetPinnableReference(nativeType) is not null, - UseDefaultMarshalling: !isMarshalUsingAttribute); - } - - static bool TryCreateTypeBasedMarshallingInfo( - ITypeSymbol type, - DefaultMarshallingInfo defaultInfo, - Compilation compilation, - GeneratorDiagnostics diagnostics, - INamedTypeSymbol scopeSymbol, - CountInfo parsedCountInfo, - int indirectionLevel, - out MarshallingInfo marshallingInfo) - { - var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); - if (conversion.Exists - && conversion.IsImplicit - && (conversion.IsReference || conversion.IsIdentity)) - { - bool hasAccessibleDefaultConstructor = false; - if (type is INamedTypeSymbol named && !named.IsAbstract && named.InstanceConstructors.Length > 0) - { - foreach (var ctor in named.InstanceConstructors) - { - if (ctor.Parameters.Length == 0) - { - hasAccessibleDefaultConstructor = compilation.IsSymbolAccessibleWithin(ctor, scopeSymbol); - break; - } - } - } - marshallingInfo = new SafeHandleMarshallingInfo(hasAccessibleDefaultConstructor); - return true; - } - - if (type is IArrayTypeSymbol { ElementType: ITypeSymbol elementType }) - { - INamedTypeSymbol? arrayMarshaller; - - if (elementType is IPointerTypeSymbol { PointedAtType: ITypeSymbol pointedAt }) - { - arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_PtrArrayMarshaller_Metadata)?.Construct(pointedAt); - } - else - { - arrayMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_GeneratedMarshalling_ArrayMarshaller_Metadata)?.Construct(elementType); - } - - if (arrayMarshaller is null) - { - // If the array marshaler type is not available, then we cannot marshal arrays. - marshallingInfo = NoMarshallingInfo.Instance; - return false; - } - - marshallingInfo = new NativeContiguousCollectionMarshallingInfo( - NativeMarshallingType: arrayMarshaller, - ValuePropertyType: ManualTypeMarshallingHelper.FindValueProperty(arrayMarshaller)?.Type, - MarshallingMethods: ~SupportedMarshallingMethods.Pinning, - NativeTypePinnable: true, - UseDefaultMarshalling: true, - ElementCountInfo: parsedCountInfo, - ElementType: elementType, - ElementMarshallingInfo: GetMarshallingInfo(elementType, Array.Empty(), defaultInfo, compilation, diagnostics, scopeSymbol, indirectionLevel + 1)); - return true; - } - - if (type is INamedTypeSymbol { IsValueType: true } valueType - && !valueType.IsExposedOutsideOfCurrentCompilation() - && valueType.IsConsideredBlittable()) - { - // Allow implicit [BlittableType] on internal value types. - marshallingInfo = new BlittableTypeAttributeInfo(); - return true; - } - - marshallingInfo = NoMarshallingInfo.Instance; - return false; - } - } - - private static TypePositionInfo? CreateForElementName(Compilation compilation, GeneratorDiagnostics diagnostics, DefaultMarshallingInfo defaultInfo, ISymbol context, string elementName) - { - if (context is IMethodSymbol method) - { - if (elementName == CountElementCountInfo.ReturnValueElementName) - { - return CreateForType( - method.ReturnType, - method.GetReturnTypeAttributes(), - defaultInfo, - compilation, - diagnostics, - method) with - { - ManagedIndex = ReturnIndex - }; - } - - foreach (var param in method.Parameters) - { - if (param.Name == elementName) - { - return CreateForParameter(param, defaultInfo, compilation, diagnostics, method); - } - } - } - else if (context is INamedTypeSymbol _) - { - // TODO: Handle when we create a struct marshalling generator - // Do we want to support CountElementName pointing to only fields, or properties as well? - // If only fields, how do we handle properties with generated backing fields? - } - - return null; - } - - private static bool AttributeAppliesToCurrentIndirectionLevel(AttributeData attrData, int indirectionLevel) - { - int elementIndirectionLevel = 0; - foreach (var arg in attrData.NamedArguments) - { - if (arg.Key == "ElementIndirectionLevel") - { - elementIndirectionLevel = (int)arg.Value.Value!; - } - } - return elementIndirectionLevel == indirectionLevel; - } - private static ByValueContentsMarshalKind GetByValueContentsMarshalKind(IEnumerable attributes, Compilation compilation) { INamedTypeSymbol outAttributeType = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_OutAttribute)!; From 2ebb57f408addb0392d4f534546abf97facf7cb0 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 18:13:29 -0700 Subject: [PATCH 15/30] Add recursion protection for CountElementName and marshalling attribute usage validation for ElementIndirectionLevel. --- .../CodeSnippets.cs | 32 +++++ .../CompileFails.cs | 5 +- .../DllImportGenerator/DllImportStub.cs | 2 +- .../MarshallingAttributeInfo.cs | 121 +++++++++++++++--- .../DllImportGenerator/TypePositionInfo.cs | 12 +- 5 files changed, 141 insertions(+), 31 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs index d8b24448b00e..21297ad381d3 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs @@ -1291,6 +1291,38 @@ partial class Test [return:MarshalAs(UnmanagedType.LPArray, SizeConst=10)] public static partial int[] Method(); } +"; + + public static string RecursiveCountElementNameOnReturnValue => @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + [return:MarshalUsing(CountElementName=MarshalUsingAttribute.ReturnsCountValue)] + public static partial int[] Method(); +} +"; + + public static string RecursiveCountElementNameOnParameter => @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + [MarshalUsing(CountElementName=""arr"")] ref int[] arr + ); +} +"; + public static string MutuallyRecursiveCountElementNameOnParameter => @" +using System.Runtime.InteropServices; +partial class Test +{ + [GeneratedDllImport(""DoesNotExist"")] + public static partial void Method( + [MarshalUsing(CountElementName=""arr2"")] ref int[] arr, + [MarshalUsing(CountElementName=""arr"")] ref int[] arr2 + ); +} "; } } diff --git a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs index c6fa9e66a3a9..d34b33c63813 100644 --- a/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs +++ b/DllImportGenerator/DllImportGenerator.UnitTests/CompileFails.cs @@ -120,7 +120,10 @@ public static IEnumerable CodeSnippetsToCompile() yield return new object[] { CodeSnippets.MarshalAsAndMarshalUsingOnReturnValue, 2, 0 }; yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingDuplicateElementIndirectionLevel, 2, 0 }; - yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionLevel, 2, 0 }; + yield return new object[] { CodeSnippets.GenericCollectionWithCustomElementMarshallingUnusedElementIndirectionLevel, 1, 0 }; + yield return new object[] { CodeSnippets.RecursiveCountElementNameOnReturnValue, 2, 0 }; + yield return new object[] { CodeSnippets.RecursiveCountElementNameOnParameter, 2, 0 }; + yield return new object[] { CodeSnippets.MutuallyRecursiveCountElementNameOnParameter, 4, 0 }; } [Theory] diff --git a/DllImportGenerator/DllImportGenerator/DllImportStub.cs b/DllImportGenerator/DllImportGenerator/DllImportStub.cs index eae8ad7d79f7..336c3f7dd377 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportStub.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportStub.cs @@ -164,7 +164,7 @@ public static DllImportStub Create( for (int i = 0; i < method.Parameters.Length; i++) { var param = method.Parameters[i]; - var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingAttributeParser, env.Compilation); + var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()), env.Compilation); typeInfo = typeInfo with { ManagedIndex = i, diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index 5f1367e367e6..522614eb60c7 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -1,14 +1,23 @@ using Microsoft.CodeAnalysis; using System; using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.Linq; using System.Runtime.InteropServices; namespace Microsoft.Interop { + + /// + /// Type used to pass on default marshalling details. + /// + internal sealed record DefaultMarshallingInfo( + CharEncoding CharEncoding + ); + // The following types are modeled to fit with the current prospective spec - // for C# 10 discriminated unions. Once discriminated unions are released, + // for C# vNext discriminated unions. Once discriminated unions are released, // these should be updated to be implemented as a discriminated union. internal abstract record MarshallingInfo @@ -173,11 +182,52 @@ internal MarshallingInfo ParseMarshallingInfo( } } - return GetMarshallingInfo(managedType, marshallingAttributesByIndirectionLevel); + return ParseMarshallingInfo(managedType, useSiteAttributes, ImmutableHashSet.Empty); + } + + private MarshallingInfo ParseMarshallingInfo( + ITypeSymbol managedType, + IEnumerable useSiteAttributes, + ImmutableHashSet inspectedElements) + { + Dictionary marshallingAttributesByIndirectionLevel = new(); + int maxIndirectionLevelDataProvided = 0; + foreach (AttributeData attribute in useSiteAttributes) + { + if (TryGetAttributeIndirectionLevel(attribute, out int indirectionLevel)) + { + if (marshallingAttributesByIndirectionLevel.ContainsKey(indirectionLevel)) + { + diagnostics.ReportConfigurationNotSupported(attribute, "Marshalling Data for Indirection Level", indirectionLevel.ToString()); + return NoMarshallingInfo.Instance; + } + marshallingAttributesByIndirectionLevel.Add(indirectionLevel, attribute); + maxIndirectionLevelDataProvided = Math.Max(maxIndirectionLevelDataProvided, indirectionLevel); + } + } + + int maxIndirectionLevelUsed = 0; + MarshallingInfo info = GetMarshallingInfo( + managedType, + marshallingAttributesByIndirectionLevel, + indirectionLevel: 0, + inspectedElements, + ref maxIndirectionLevelUsed); + if (maxIndirectionLevelUsed < maxIndirectionLevelDataProvided) + { + diagnostics.ReportConfigurationNotSupported(marshallingAttributesByIndirectionLevel[maxIndirectionLevelDataProvided], "ElementIndirectionLevel", maxIndirectionLevelDataProvided.ToString()); + } + return info; } - private MarshallingInfo GetMarshallingInfo(ITypeSymbol type, Dictionary useSiteAttributes, int indirectionLevel = 0) + private MarshallingInfo GetMarshallingInfo( + ITypeSymbol type, + Dictionary useSiteAttributes, + int indirectionLevel, + ImmutableHashSet inspectedElements, + ref int maxIndirectionLevelUsed) { + maxIndirectionLevelUsed = Math.Max(indirectionLevel, maxIndirectionLevelUsed); CountInfo parsedCountInfo = NoCountInfo.Instance; if (useSiteAttributes.TryGetValue(indirectionLevel, out AttributeData useSiteAttribute)) @@ -188,7 +238,7 @@ private MarshallingInfo GetMarshallingInfo(ITypeSymbol type, Dictionary inspectedElements) { int? constSize = null; string? elementName = null; @@ -285,7 +353,13 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) } else if (elementName is not null) { - TypePositionInfo? elementInfo = CreateForElementName(elementName); + if (inspectedElements.Contains(elementName)) + { + diagnostics.ReportConfigurationNotSupported(marshalUsingData, "Cyclical CountElementName"); + return NoCountInfo.Instance; + } + + TypePositionInfo? elementInfo = CreateForElementName(elementName, inspectedElements.Add(elementName)); if (elementInfo is null) { diagnostics.ReportConfigurationNotSupported(marshalUsingData, "CountElementName", elementName); @@ -297,7 +371,7 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) return NoCountInfo.Instance; } - private TypePositionInfo? CreateForElementName(string elementName) + private TypePositionInfo? CreateForElementName(string elementName, ImmutableHashSet inspectedElements) { if (contextSymbol is IMethodSymbol method) { @@ -305,7 +379,7 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) { return TypePositionInfo.CreateForType( method.ReturnType, - ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes())) with + ParseMarshallingInfo(method.ReturnType, method.GetReturnTypeAttributes(), inspectedElements)) with { ManagedIndex = TypePositionInfo.ReturnIndex }; @@ -315,7 +389,7 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) { if (param.Name == elementName) { - return TypePositionInfo.CreateForParameter(param, this, compilation); + return TypePositionInfo.CreateForParameter(param, ParseMarshallingInfo(param.Type, param.GetAttributes(), inspectedElements), compilation); } } } @@ -331,7 +405,8 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData) MarshallingInfo CreateInfoFromMarshalAs( ITypeSymbol type, - AttributeData attrData) + AttributeData attrData, + ref int maxIndirectionLevelUsed) { object unmanagedTypeObj = attrData.ConstructorArguments[0].Value!; UnmanagedType unmanagedType = unmanagedTypeObj is short @@ -405,8 +480,8 @@ MarshallingInfo CreateInfoFromMarshalAs( } else { - // Indirection level does not matter since we don't pass down attributes to be inspected. - elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 0); + maxIndirectionLevelUsed = 1; + elementMarshallingInfo = GetMarshallingInfo(elementType, new Dictionary(), 1, ImmutableHashSet.Empty, ref maxIndirectionLevelUsed); } INamedTypeSymbol? arrayMarshaller; @@ -437,7 +512,15 @@ MarshallingInfo CreateInfoFromMarshalAs( ElementMarshallingInfo: elementMarshallingInfo); } - MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attrData, bool isMarshalUsingAttribute, int indirectionLevel, CountInfo parsedCountInfo, Dictionary useSiteAttributes) + MarshallingInfo CreateNativeMarshallingInfo( + ITypeSymbol type, + AttributeData attrData, + bool isMarshalUsingAttribute, + int indirectionLevel, + CountInfo parsedCountInfo, + Dictionary useSiteAttributes, + ImmutableHashSet inspectedElements, + ref int maxIndirectionLevelUsed) { SupportedMarshallingMethods methods = SupportedMarshallingMethods.None; @@ -538,7 +621,7 @@ MarshallingInfo CreateNativeMarshallingInfo(ITypeSymbol type, AttributeData attr UseDefaultMarshalling: !isMarshalUsingAttribute, parsedCountInfo, elementType, - GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1)); + GetMarshallingInfo(elementType, useSiteAttributes, maxIndirectionLevelUsed = indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); } return new NativeMarshallingAttributeInfo( @@ -554,6 +637,8 @@ bool TryCreateTypeBasedMarshallingInfo( CountInfo parsedCountInfo, int indirectionLevel, Dictionary useSiteAttributes, + ImmutableHashSet inspectedElements, + ref int maxIndirectionLevelUsed, out MarshallingInfo marshallingInfo) { var conversion = compilation.ClassifyCommonConversion(type, compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_SafeHandle)!); @@ -605,7 +690,7 @@ bool TryCreateTypeBasedMarshallingInfo( UseDefaultMarshalling: true, ElementCountInfo: parsedCountInfo, ElementType: elementType, - ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1)); + ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, maxIndirectionLevelUsed = indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); return true; } diff --git a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs index 7c494108320a..0f011d9f0846 100644 --- a/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs +++ b/DllImportGenerator/DllImportGenerator/TypePositionInfo.cs @@ -1,8 +1,5 @@ using System; using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Runtime.InteropServices; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -10,12 +7,6 @@ namespace Microsoft.Interop { - /// - /// Type used to pass on default marshalling details. - /// - internal sealed record DefaultMarshallingInfo ( - CharEncoding CharEncoding - ); /// /// Describes how to marshal the contents of a value in comparison to the value itself. @@ -83,9 +74,8 @@ private TypePositionInfo() public MarshallingInfo MarshallingAttributeInfo { get; init; } - public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingAttributeInfoParser attributeParser, Compilation compilation) + public static TypePositionInfo CreateForParameter(IParameterSymbol paramSymbol, MarshallingInfo marshallingInfo, Compilation compilation) { - var marshallingInfo = attributeParser.ParseMarshallingInfo(paramSymbol.Type, paramSymbol.GetAttributes()); var typeInfo = new TypePositionInfo() { ManagedType = paramSymbol.Type, From a745e84d7f2bb1169efedd3eb1c1ccffc320e4d5 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 1 Jun 2021 19:17:07 -0700 Subject: [PATCH 16/30] Use ImmutableHashSet to enable fast lookup and automatic "pop" behavior for recursion protection. --- .../DllImportGenerator/TypeSymbolExtensions.cs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/TypeSymbolExtensions.cs b/DllImportGenerator/DllImportGenerator/TypeSymbolExtensions.cs index 4495704b3dcc..aff8b91be2df 100644 --- a/DllImportGenerator/DllImportGenerator/TypeSymbolExtensions.cs +++ b/DllImportGenerator/DllImportGenerator/TypeSymbolExtensions.cs @@ -13,9 +13,9 @@ namespace Microsoft.Interop { static class TypeSymbolExtensions { - public static bool HasOnlyBlittableFields(this ITypeSymbol type) => HasOnlyBlittableFields(type, new HashSet(SymbolEqualityComparer.Default)); + public static bool HasOnlyBlittableFields(this ITypeSymbol type) => HasOnlyBlittableFields(type, ImmutableHashSet.Create(SymbolEqualityComparer.Default)); - private static bool HasOnlyBlittableFields(this ITypeSymbol type, HashSet seenTypes) + private static bool HasOnlyBlittableFields(this ITypeSymbol type, ImmutableHashSet seenTypes) { if (seenTypes.Contains(type)) { @@ -24,7 +24,7 @@ private static bool HasOnlyBlittableFields(this ITypeSymbol type, HashSet()) { if (!field.IsStatic) @@ -39,18 +39,16 @@ private static bool HasOnlyBlittableFields(this ITypeSymbol type, HashSet true, { Type: { IsValueType: false } } => false, - _ => IsConsideredBlittable(field.Type, seenTypes) + _ => IsConsideredBlittable(field.Type, seenTypes.Add(type)) }; if (!fieldBlittable) { - seenTypes.Remove(type); return false; } } } - seenTypes.Remove(type); return true; } @@ -72,9 +70,9 @@ or SpecialType.System_IntPtr _ => false }; - public static bool IsConsideredBlittable(this ITypeSymbol type) => IsConsideredBlittable(type, new HashSet(SymbolEqualityComparer.Default)); + public static bool IsConsideredBlittable(this ITypeSymbol type) => IsConsideredBlittable(type, ImmutableHashSet.Create(SymbolEqualityComparer.Default)); - private static bool IsConsideredBlittable(this ITypeSymbol type, HashSet seenTypes) + private static bool IsConsideredBlittable(this ITypeSymbol type, ImmutableHashSet seenTypes) { if (type.SpecialType != SpecialType.None) { From d9a93087b57d0cee8933aed352201b17d1834656 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 2 Jun 2021 10:36:51 -0700 Subject: [PATCH 17/30] Add tests with basic collection marshalling. --- .../CollectionTests.cs | 232 ++++++++++++++++++ .../TestAssets/SharedTypes/NonBlittable.cs | 112 +++++++++ 2 files changed, 344 insertions(+) create mode 100644 DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs diff --git a/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs b/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs new file mode 100644 index 000000000000..b5ba088da565 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs @@ -0,0 +1,232 @@ +using SharedTypes; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +using Xunit; + +namespace DllImportGenerator.IntegrationTests +{ + partial class NativeExportsNE + { + public partial class Collections + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum([MarshalUsing(typeof(ListMarshaller))] List values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum(ref int values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")] + public static partial int SumInArray([MarshalUsing(typeof(ListMarshaller))] in List values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "duplicate_int_array")] + public static partial void Duplicate([MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] ref List values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array")] + [return:MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] + public static partial List CreateRange(int start, int end, out int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] + public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] out List res); + + //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] + //public static partial int SumStringLengths([MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] List strArray); + + //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_replace")] + //public static partial void ReverseStrings_Ref([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] ref List strArray, out int numElements); + + //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_return")] + //[return: MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] + //public static partial List ReverseStrings_Return([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] List strArray, out int numElements); + + //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_out")] + //public static partial void ReverseStrings_Out( + // [MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] List strArray, + // out int numElements, + // [MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] out List res); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] + [return:MarshalUsing(typeof(ListMarshaller), ConstantElementCount = sizeof(long))] + public static partial List GetLongBytes(long l); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")] + [return:MarshalAs(UnmanagedType.U1)] + public static partial bool AndAllMembers([MarshalUsing(typeof(ListMarshaller))] List pArray, int length); + } + } + + public class CollectionTests + { + [Fact] + public void BlittableElementColllectionMarshalledToNativeAsExpected() + { + var list = new List { 1, 5, 79, 165, 32, 3 }; + Assert.Equal(list.Sum(), NativeExportsNE.Collections.Sum(list, list.Count)); + } + + [Fact] + public void NullBlittableElementColllectionMarshalledToNativeAsExpected() + { + Assert.Equal(-1, NativeExportsNE.Collections.Sum(null, 0)); + } + + [Fact] + public void BlittableElementColllectionInParameter() + { + var list = new List { 1, 5, 79, 165, 32, 3 }; + Assert.Equal(list.Sum(), NativeExportsNE.Collections.SumInArray(list, list.Count)); + } + + [Fact] + public void BlittableElementCollectionRefParameter() + { + var list = new List { 1, 5, 79, 165, 32, 3 }; + var newList = list; + NativeExportsNE.Collections.Duplicate(ref newList, list.Count); + Assert.Equal((IEnumerable)list, newList); + } + + [Fact] + public void BlittableElementCollectionReturnedFromNative() + { + int start = 5; + int end = 20; + + IEnumerable expected = Enumerable.Range(start, end - start); + Assert.Equal(expected, NativeExportsNE.Collections.CreateRange(start, end, out _)); + + List res; + NativeExportsNE.Collections.CreateRange_Out(start, end, out _, out res); + Assert.Equal(expected, res); + } + + [Fact] + public void NullBlittableElementCollectionReturnedFromNative() + { + Assert.Null(NativeExportsNE.Collections.CreateRange(1, 0, out _)); + + List res; + NativeExportsNE.Collections.CreateRange_Out(1, 0, out _, out res); + Assert.Null(res); + } + + private static List GetStringList() + { + return new() + { + "ABCdef 123$%^", + "🍜 !! 🍜 !!", + "🌲 木 🔥 火 🌾 土 🛡 金 🌊 水" , + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed vitae posuere mauris, sed ultrices leo. Suspendisse potenti. Mauris enim enim, blandit tincidunt consequat in, varius sit amet neque. Morbi eget porttitor ex. Duis mattis aliquet ante quis imperdiet. Duis sit.", + string.Empty, + null + }; + } + + //[Fact] + //public void ByValueCollectionWithNonBlittableElements() + //{ + // var strings = GetStringList(); + // Assert.Equal(strings.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.SumStringLengths(strings)); + //} + + //[Fact] + //public void ByValueNullCollectionWithNonBlittableElements() + //{ + // Assert.Equal(0, NativeExportsNE.Collections.SumStringLengths(null)); + //} + + //[Fact] + //public void ByRefCollectionWithNonBlittableElements() + //{ + // var strings = GetStringList(); + // var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); + // NativeExportsNE.Collections.ReverseStrings_Ref(ref strings, out _); + + // Assert.Equal((IEnumerable)expectedStrings, strings); + //} + + //[Fact] + //public void ReturnCollectionWithNonBlittableElements() + //{ + // var strings = GetStringList(); + // var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); + // Assert.Equal(expectedStrings, NativeExportsNE.Collections.ReverseStrings_Return(strings, out _)); + + // List res; + // NativeExportsNE.Collections.ReverseStrings_Out(strings, out _, out res); + // Assert.Equal(expectedStrings, res); + //} + + //[Fact] + //public void ByRefNullCollectionWithNonBlittableElements() + //{ + // List strings = null; + // NativeExportsNE.Collections.ReverseStrings_Ref(ref strings, out _); + + // Assert.Null(strings); + //} + + //[Fact] + //public void ReturnNullCollectionWithNonBlittableElements() + //{ + // List strings = null; + // Assert.Null(NativeExportsNE.Collections.ReverseStrings_Return(strings, out _)); + + // List res; + // NativeExportsNE.Collections.ReverseStrings_Out(strings, out _, out res); + // Assert.Null(res); + //} + + [Fact] + public void ConstantSizeCollection() + { + var longVal = 0x12345678ABCDEF10L; + + Assert.Equal(longVal, MemoryMarshal.Read(CollectionsMarshal.AsSpan(NativeExportsNE.Collections.GetLongBytes(longVal)))); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CollectionWithSimpleNonBlittableTypeMarshalling(bool result) + { + var boolValues = new List + { + new BoolStruct + { + b1 = true, + b2 = true, + b3 = true, + }, + new BoolStruct + { + b1 = true, + b2 = true, + b3 = true, + }, + new BoolStruct + { + b1 = true, + b2 = true, + b3 = result, + }, + }; + + Assert.Equal(result, NativeExportsNE.Collections.AndAllMembers(boolValues, boolValues.Count)); + } + + private static string ReverseChars(string value) + { + if (value == null) + return null; + + var chars = value.ToCharArray(); + Array.Reverse(chars); + return new string(chars); + } + } +} diff --git a/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs b/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs index aee951b165ea..44ed63bfad56 100644 --- a/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs +++ b/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.InteropServices; namespace SharedTypes @@ -217,4 +219,114 @@ public IntStructWrapperNative(IntStructWrapper managed) public IntStructWrapper ToManaged() => new IntStructWrapper { Value = value }; } + + [GenericContiguousCollectionMarshaller] + public unsafe ref struct ListMarshaller + { + private List managedList; + private readonly int sizeOfNativeElement; + private IntPtr allocatedMemory; + + public ListMarshaller(int sizeOfNativeElement) + : this() + { + this.sizeOfNativeElement = sizeOfNativeElement; + } + + public ListMarshaller(List managed, int sizeOfNativeElement) + { + allocatedMemory = default; + this.sizeOfNativeElement = sizeOfNativeElement; + if (managed is null) + { + managedList = null; + NativeValueStorage = default; + return; + } + managedList = managed; + this.sizeOfNativeElement = sizeOfNativeElement; + // Always allocate at least one byte when the array is zero-length. + int spaceToAllocate = Math.Max(managed.Count * sizeOfNativeElement, 1); + allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)allocatedMemory, spaceToAllocate); + } + + public ListMarshaller(List managed, Span stackSpace, int sizeOfNativeElement) + { + allocatedMemory = default; + this.sizeOfNativeElement = sizeOfNativeElement; + if (managed is null) + { + managedList = null; + NativeValueStorage = default; + return; + } + managedList = managed; + // Always allocate at least one byte when the array is zero-length. + int spaceToAllocate = Math.Max(managed.Count * sizeOfNativeElement, 1); + if (spaceToAllocate < stackSpace.Length) + { + NativeValueStorage = stackSpace[0..spaceToAllocate]; + } + else + { + allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)allocatedMemory, spaceToAllocate); + } + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of array parameters doesn't + /// blow the stack since this is a new optimization in the code-generated interop. + /// + public const int StackBufferSize = 0x200; + + public Span ManagedValues => CollectionsMarshal.AsSpan(managedList); + + public Span NativeValueStorage { get; private set; } + + public ref byte GetPinnableReference() => ref NativeValueStorage.GetPinnableReference(); + + public void SetUnmarshalledCollectionLength(int length) + { + managedList = new List(length); + for (int i = 0; i < length; i++) + { + managedList.Add(default); + } + } + + public byte* Value + { + get + { + Debug.Assert(managedList is null || allocatedMemory != IntPtr.Zero); + return (byte*)allocatedMemory; + } + set + { + if (value == null) + { + managedList = null; + NativeValueStorage = default; + } + else + { + allocatedMemory = (IntPtr)value; + NativeValueStorage = new Span(value, (managedList?.Count ?? 0) * sizeOfNativeElement); + } + } + } + + public List ToManaged() => managedList; + + public void FreeNative() + { + if (allocatedMemory != IntPtr.Zero) + { + Marshal.FreeCoTaskMem(allocatedMemory); + } + } + } } From 9a350c59fe73f77918c8bf9b7a5e59fe6e434a17 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Wed, 2 Jun 2021 10:38:13 -0700 Subject: [PATCH 18/30] Remove commented out code. --- .../CollectionTests.cs | 71 ------------------- 1 file changed, 71 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs b/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs index b5ba088da565..a5be7cf5d1c6 100644 --- a/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs +++ b/DllImportGenerator/DllImportGenerator.IntegrationTests/CollectionTests.cs @@ -32,22 +32,6 @@ public partial class Collections [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(ListMarshaller), CountElementName = "numValues")] out List res); - //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] - //public static partial int SumStringLengths([MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] List strArray); - - //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_replace")] - //public static partial void ReverseStrings_Ref([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] ref List strArray, out int numElements); - - //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_return")] - //[return: MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] - //public static partial List ReverseStrings_Return([MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] List strArray, out int numElements); - - //[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_out")] - //public static partial void ReverseStrings_Out( - // [MarshalUsing(typeof(ListMarshaller)), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] List strArray, - // out int numElements, - // [MarshalUsing(typeof(ListMarshaller), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaler), ElementIndirectionLevel = 1)] out List res); - [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] [return:MarshalUsing(typeof(ListMarshaller), ConstantElementCount = sizeof(long))] public static partial List GetLongBytes(long l); @@ -125,61 +109,6 @@ private static List GetStringList() null }; } - - //[Fact] - //public void ByValueCollectionWithNonBlittableElements() - //{ - // var strings = GetStringList(); - // Assert.Equal(strings.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.SumStringLengths(strings)); - //} - - //[Fact] - //public void ByValueNullCollectionWithNonBlittableElements() - //{ - // Assert.Equal(0, NativeExportsNE.Collections.SumStringLengths(null)); - //} - - //[Fact] - //public void ByRefCollectionWithNonBlittableElements() - //{ - // var strings = GetStringList(); - // var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); - // NativeExportsNE.Collections.ReverseStrings_Ref(ref strings, out _); - - // Assert.Equal((IEnumerable)expectedStrings, strings); - //} - - //[Fact] - //public void ReturnCollectionWithNonBlittableElements() - //{ - // var strings = GetStringList(); - // var expectedStrings = strings.Select(s => ReverseChars(s)).ToList(); - // Assert.Equal(expectedStrings, NativeExportsNE.Collections.ReverseStrings_Return(strings, out _)); - - // List res; - // NativeExportsNE.Collections.ReverseStrings_Out(strings, out _, out res); - // Assert.Equal(expectedStrings, res); - //} - - //[Fact] - //public void ByRefNullCollectionWithNonBlittableElements() - //{ - // List strings = null; - // NativeExportsNE.Collections.ReverseStrings_Ref(ref strings, out _); - - // Assert.Null(strings); - //} - - //[Fact] - //public void ReturnNullCollectionWithNonBlittableElements() - //{ - // List strings = null; - // Assert.Null(NativeExportsNE.Collections.ReverseStrings_Return(strings, out _)); - - // List res; - // NativeExportsNE.Collections.ReverseStrings_Out(strings, out _, out res); - // Assert.Null(res); - //} [Fact] public void ConstantSizeCollection() From 7d2a90d6eb7a776f2ad8c30cf3be95d75e0a5de7 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 3 Jun 2021 10:43:30 -0700 Subject: [PATCH 19/30] Generate intermediate cleanup statements even if there is no FreeNative method on a marshaller type.. --- .../Marshalling/CustomNativeTypeMarshaller.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs index 4f6674d4ff67..0c42f0643d4d 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs @@ -254,12 +254,12 @@ public virtual IEnumerable Generate(TypePositionInfo info, Stub } break; case StubCodeContext.Stage.Cleanup: + foreach (var statement in GenerateIntermediateCleanupStatements(info, context)) + { + yield return statement; + } if (_hasFreeNative) { - foreach (var statement in GenerateIntermediateCleanupStatements(info, context)) - { - yield return statement; - } // .FreeNative(); yield return ExpressionStatement( InvocationExpression( From 992cca4d89c9ba797c72d6cf692ca58f8c8b69d7 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 4 Jun 2021 17:02:17 -0700 Subject: [PATCH 20/30] Add a nice big comment block explaining the GetIdentifiers logic around managed/native return positions. --- .../DllImportGenerator/StubCodeGenerator.cs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs index c1c318cc4302..15746deb0cbf 100644 --- a/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/StubCodeGenerator.cs @@ -94,12 +94,23 @@ public StubCodeGenerator( public override (string managed, string native) GetIdentifiers(TypePositionInfo info) { + // If the info is in the managed return position, then we need to generate a name to use + // for both the managed and native values since there is no name in the signature for the return value. if (info.IsManagedReturnPosition) { return (ReturnIdentifier, ReturnNativeIdentifier); } - else if (!info.IsManagedReturnPosition && info.IsNativeReturnPosition) - { + // If the info is in the native return position but is not in the managed return position, + // then that means that the stub is introducing an additional info for the return position. + // This means that there is no name in source for this info, so we must provide one here. + // We can't use ReturnIdentifier or ReturnNativeIdentifier since that will be used by the managed return value. + // Additionally, since all use cases today of a TypePositionInfo in the native position but not the managed + // are for infos that aren't in the managed signature at all (PreserveSig scenario), we don't have a name + // that we can use from source. As a result, we generate another name for the native return value + // and use the same name for native and managed. + else if (info.IsNativeReturnPosition) + { + Debug.Assert(info.ManagedIndex == TypePositionInfo.UnsetIndex); return (InvokeReturnIdentifier, InvokeReturnIdentifier); } else From 99b8f743f28bb3da30a91897cbd804f62ec23a7c Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 4 Jun 2021 17:05:16 -0700 Subject: [PATCH 21/30] Fix edge case for stackalloc size check. --- DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs | 4 ++-- DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs index 4710fb8337cb..370e90efb3d2 100644 --- a/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs +++ b/DllImportGenerator/Ancillary.Interop/ArrayMarshaller.cs @@ -47,7 +47,7 @@ public ArrayMarshaller(T[]? managed, Span stackSpace, int sizeOfNativeElem managedArray = managed; // Always allocate at least one byte when the array is zero-length. int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); - if (spaceToAllocate < stackSpace.Length) + if (spaceToAllocate <= stackSpace.Length) { NativeValueStorage = stackSpace[0..spaceToAllocate]; } @@ -152,7 +152,7 @@ public PtrArrayMarshaller(T*[]? managed, Span stackSpace, int sizeOfNative managedArray = managed; // Always allocate at least one byte when the array is zero-length. int spaceToAllocate = Math.Max(managed.Length * sizeOfNativeElement, 1); - if (spaceToAllocate < stackSpace.Length) + if (spaceToAllocate <= stackSpace.Length) { NativeValueStorage = stackSpace[0..spaceToAllocate]; } diff --git a/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs b/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs index 44ed63bfad56..35275afa0916 100644 --- a/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs +++ b/DllImportGenerator/TestAssets/SharedTypes/NonBlittable.cs @@ -264,7 +264,7 @@ public ListMarshaller(List managed, Span stackSpace, int sizeOfNativeEl managedList = managed; // Always allocate at least one byte when the array is zero-length. int spaceToAllocate = Math.Max(managed.Count * sizeOfNativeElement, 1); - if (spaceToAllocate < stackSpace.Length) + if (spaceToAllocate <= stackSpace.Length) { NativeValueStorage = stackSpace[0..spaceToAllocate]; } From d8829e57040b2d845ded5a8868f8f7011a1b4e81 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 4 Jun 2021 17:11:52 -0700 Subject: [PATCH 22/30] Introduce local for easier debuggability. --- DllImportGenerator/DllImportGenerator/DllImportStub.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/DllImportGenerator/DllImportGenerator/DllImportStub.cs b/DllImportGenerator/DllImportGenerator/DllImportStub.cs index 336c3f7dd377..f74f9b2a3cff 100644 --- a/DllImportGenerator/DllImportGenerator/DllImportStub.cs +++ b/DllImportGenerator/DllImportGenerator/DllImportStub.cs @@ -164,7 +164,8 @@ public static DllImportStub Create( for (int i = 0; i < method.Parameters.Length; i++) { var param = method.Parameters[i]; - var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()), env.Compilation); + MarshallingInfo marshallingInfo = marshallingAttributeParser.ParseMarshallingInfo(param.Type, param.GetAttributes()); + var typeInfo = TypePositionInfo.CreateForParameter(param, marshallingInfo, env.Compilation); typeInfo = typeInfo with { ManagedIndex = i, From d1cd18562b83738f9867270a24ce2014fb50374e Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 4 Jun 2021 17:11:59 -0700 Subject: [PATCH 23/30] Introduce constants. --- .../ManualTypeMarshallingHelper.cs | 7 ++++ .../MarshallingAttributeInfo.cs | 34 ++++++------------- 2 files changed, 17 insertions(+), 24 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index 1dc51ded21f5..ab9000d94bad 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -16,6 +16,13 @@ static class ManualTypeMarshallingHelper public const string NativeValueStoragePropertyName = "NativeValueStorage"; public const string SetUnmarshalledCollectionLengthMethodName = "SetUnmarshalledCollectionLength"; + public static class MarshalUsingProperties + { + public const string ElementIndirectionLevel = "ElementIndirectionLevel"; + public const string CountElementName = "CountElementName"; + public const string ConstantElementCount = "ConstantElementCount"; + } + public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol managedType) { return nativeType.GetMembers(ToManagedMethodName) diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index 1b4b098f6b4d..47a481ab324a 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -141,7 +141,7 @@ internal sealed record NativeContiguousCollectionMarshallingInfo( UseDefaultMarshalling ); - class MarshallingAttributeInfoParser + internal class MarshallingAttributeInfoParser { private readonly Compilation compilation; private readonly GeneratorDiagnostics diagnostics; @@ -164,24 +164,10 @@ public MarshallingAttributeInfoParser( marshalUsingAttribute = compilation.GetTypeByMetadataName(TypeNames.MarshalUsingAttribute)!; } - internal MarshallingInfo ParseMarshallingInfo( + public MarshallingInfo ParseMarshallingInfo( ITypeSymbol managedType, IEnumerable useSiteAttributes) { - Dictionary marshallingAttributesByIndirectionLevel = new(); - foreach (AttributeData attribute in useSiteAttributes) - { - if (TryGetAttributeIndirectionLevel(attribute, out int indirectionLevel)) - { - if (marshallingAttributesByIndirectionLevel.ContainsKey(indirectionLevel)) - { - diagnostics.ReportConfigurationNotSupported(attribute, "Marshalling Data for Indirection Level", indirectionLevel.ToString()); - return NoMarshallingInfo.Instance; - } - marshallingAttributesByIndirectionLevel.Add(indirectionLevel, attribute); - } - } - return ParseMarshallingInfo(managedType, useSiteAttributes, ImmutableHashSet.Empty); } @@ -215,7 +201,7 @@ private MarshallingInfo ParseMarshallingInfo( ref maxIndirectionLevelUsed); if (maxIndirectionLevelUsed < maxIndirectionLevelDataProvided) { - diagnostics.ReportConfigurationNotSupported(marshallingAttributesByIndirectionLevel[maxIndirectionLevelDataProvided], "ElementIndirectionLevel", maxIndirectionLevelDataProvided.ToString()); + diagnostics.ReportConfigurationNotSupported(marshallingAttributesByIndirectionLevel[maxIndirectionLevelDataProvided], ManualTypeMarshallingHelper.MarshalUsingProperties.ElementIndirectionLevel, maxIndirectionLevelDataProvided.ToString()); } return info; } @@ -328,15 +314,15 @@ CountInfo CreateCountInfo(AttributeData marshalUsingData, ImmutableHashSet Date: Fri, 4 Jun 2021 17:12:34 -0700 Subject: [PATCH 24/30] Seal attributes. --- .../Ancillary.Interop/GeneratedMarshallingAttribute.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs index 6939ffcf568b..b8ff26d48e90 100644 --- a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs +++ b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs @@ -2,17 +2,17 @@ namespace System.Runtime.InteropServices { [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct)] - class GeneratedMarshallingAttribute : Attribute + sealed class GeneratedMarshallingAttribute : Attribute { } [AttributeUsage(AttributeTargets.Struct)] - public class BlittableTypeAttribute : Attribute + public sealed class BlittableTypeAttribute : Attribute { } [AttributeUsage(AttributeTargets.Struct | AttributeTargets.Class)] - public class NativeMarshallingAttribute : Attribute + public sealed class NativeMarshallingAttribute : Attribute { public NativeMarshallingAttribute(Type nativeType) { @@ -23,7 +23,7 @@ public NativeMarshallingAttribute(Type nativeType) } [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.ReturnValue | AttributeTargets.Field, AllowMultiple = true)] - public class MarshalUsingAttribute : Attribute + public sealed class MarshalUsingAttribute : Attribute { public MarshalUsingAttribute() { From e2a15acbe4881c5446d1a5f227af613e4c378f37 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 7 Jun 2021 10:55:52 -0700 Subject: [PATCH 25/30] string.Empty instead of null. --- .../Ancillary.Interop/GeneratedMarshallingAttribute.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs index b8ff26d48e90..dd605b9e60bb 100644 --- a/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs +++ b/DllImportGenerator/Ancillary.Interop/GeneratedMarshallingAttribute.cs @@ -27,7 +27,7 @@ public sealed class MarshalUsingAttribute : Attribute { public MarshalUsingAttribute() { - CountElementName = null!; + CountElementName = string.Empty; } public MarshalUsingAttribute(Type nativeType) From 594aa686e5cff8e1a4dc9cf3891bf839aa362347 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 7 Jun 2021 17:21:04 -0700 Subject: [PATCH 26/30] Use a combination of the Strategy and Decoration patterns to implement the custom type and collection marshalling generation in a less confusing albeit more verbose way. --- .../Marshalling/ArrayMarshaller.cs | 234 ++--- ...onBlittableElementsMarshallingGenerator.cs | 128 --- ...onBlittableElementsMarshallingGenerator.cs | 176 ---- .../Marshalling/CustomNativeTypeMarshaller.cs | 324 ------ .../CustomNativeTypeMarshallingGenerator.cs | 83 ++ .../ICustomNativeTypeMarshallingStrategy.cs | 969 ++++++++++++++++++ .../Marshalling/MarshallingGenerator.cs | 118 ++- .../PinnableManagedValueMarshaller.cs | 84 ++ 8 files changed, 1335 insertions(+), 781 deletions(-) delete mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs delete mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs delete mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs create mode 100644 DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs index 7649919fbfb1..c5838517bb3d 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -1,167 +1,141 @@ -using System; -using System.Collections.Generic; - -using Microsoft.CodeAnalysis; +using System.Collections.Generic; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Microsoft.Interop { - internal class ArrayMarshaller : CustomNativeTypeMarshaller + internal class ArrayMarshaller : IMarshallingGenerator { - private readonly CustomNativeTypeMarshaller innerCollectionMarshaller; - private readonly bool blittable; - - public ArrayMarshaller( - ContiguousCollectionBlittableElementsMarshallingGenerator innerCollectionMarshaller, - NativeContiguousCollectionMarshallingInfo marshallingInfo) - : base(marshallingInfo) - { - this.innerCollectionMarshaller = innerCollectionMarshaller; - blittable = true; - } + private readonly IMarshallingGenerator manualMarshallingGenerator; + private readonly TypeSyntax elementType; + private readonly bool enablePinning; - public ArrayMarshaller( - ContiguousCollectionNonBlittableElementsMarshallingGenerator innerCollectionMarshaller, - NativeContiguousCollectionMarshallingInfo marshallingInfo) - : base(marshallingInfo) + public ArrayMarshaller(IMarshallingGenerator manualMarshallingGenerator, TypeSyntax elementType, bool enablePinning) { - this.innerCollectionMarshaller = innerCollectionMarshaller; - blittable = false; + this.manualMarshallingGenerator = manualMarshallingGenerator; + this.elementType = elementType; + this.enablePinning = enablePinning; } - private bool UseCustomPinningPath(TypePositionInfo info, StubCodeContext context) + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) { - return blittable && !info.IsByRef && !info.IsManagedReturnPosition && context.PinningSupported; - } - - public override IEnumerable Generate(TypePositionInfo info, StubCodeContext context) - { - if (UseCustomPinningPath(info, context)) - { - return GenerateCustomPinning(); - } - - if (context.CurrentStage == StubCodeContext.Stage.Unmarshal - && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (IsPinningPathSupported(info, context)) { - // For [Out] by value unmarshalling, we emit custom code that only copies the elements. - // We do not call SetUnmarshalledCollectionLength since that creates a new - // array, and we want to fill the original one. - return innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context); - } - - return innerCollectionMarshaller.Generate(info, context); - - IEnumerable GenerateCustomPinning() - { - var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); - string byRefIdentifier = $"__byref_{managedIdentifer}"; - TypeSyntax arrayElementType = ((IArrayTypeSymbol)info.ManagedType).ElementType.AsTypeSyntax(); - if (context.CurrentStage == StubCodeContext.Stage.Marshal) - { - // [COMPAT] We use explicit byref calculations here instead of just using a fixed statement - // since a fixed statement converts a zero-length array to a null pointer. - // Many native APIs, such as GDI+, ICU, etc. validate that an array parameter is non-null - // even when the passed in array length is zero. To avoid breaking customers that want to move - // to source-generated interop in subtle ways, we explicitly pass a reference to the 0-th element - // of an array as long as it is non-null, matching the behavior of the built-in interop system - // for single-dimensional zero-based arrays. - - // ref = == null ? ref *(); - var nullRef = - PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, - CastExpression( - PointerType(arrayElementType), - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))); - - var getArrayDataReference = - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), - IdentifierName("GetArrayDataReference")), - ArgumentList(SingletonSeparatedList( - Argument(IdentifierName(managedIdentifer))))); - - yield return LocalDeclarationStatement( - VariableDeclaration( - RefType(arrayElementType)) - .WithVariables(SingletonSeparatedList( - VariableDeclarator(Identifier(byRefIdentifier)) - .WithInitializer(EqualsValueClause( - RefExpression(ParenthesizedExpression( - ConditionalExpression( - BinaryExpression( - SyntaxKind.EqualsExpression, - IdentifierName(managedIdentifer), - LiteralExpression( - SyntaxKind.NullLiteralExpression)), - RefExpression(nullRef), - RefExpression(getArrayDataReference))))))))); - } - if (context.CurrentStage == StubCodeContext.Stage.Pin) - { - // fixed ( = &) - yield return FixedStatement( - VariableDeclaration(AsNativeType(info), SingletonSeparatedList( - VariableDeclarator(nativeIdentifier) - .WithInitializer(EqualsValueClause( - PrefixUnaryExpression(SyntaxKind.AddressOfExpression, - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_CompilerServices_Unsafe), - GenericName("As").AddTypeArgumentListArguments( - arrayElementType, - PredefinedType(Token(SyntaxKind.ByteKeyword))))) - .AddArgumentListArguments( - Argument(IdentifierName(byRefIdentifier)) - .WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))))), - EmptyStatement()); - } - yield break; + string identifier = context.GetIdentifiers(info).native; + return Argument(CastExpression(AsNativeType(info), IdentifierName(identifier))); } + return manualMarshallingGenerator.AsArgument(info, context); } - public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + public TypeSyntax AsNativeType(TypePositionInfo info) { - return innerCollectionMarshaller.GenerateAdditionalNativeTypeConstructorArguments(info, context); + return manualMarshallingGenerator.AsNativeType(info); } - public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) + public ParameterSyntax AsParameter(TypePositionInfo info) { - if (info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) - { - // Don't marshal contents of an array when it is marshalled by value [Out]. - return Array.Empty(); - } - return innerCollectionMarshaller.GenerateIntermediateMarshallingStatements(info, context); + return manualMarshallingGenerator.AsParameter(info); } - public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { - return innerCollectionMarshaller.GeneratePreUnmarshallingStatements(info, context); + if (IsPinningPathSupported(info, context)) + { + return GeneratePinningPath(info, context); + } + return manualMarshallingGenerator.Generate(info, context); } - public override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) + public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) { - return innerCollectionMarshaller.GenerateIntermediateUnmarshallingStatements(info, context); + if (!(context.PinningSupported && enablePinning)) + { + return marshalKind.HasFlag(ByValueContentsMarshalKind.Out); + } + return manualMarshallingGenerator.SupportsByValueMarshalKind(marshalKind, context); } - public override IEnumerable GenerateIntermediateCleanupStatements(TypePositionInfo info, StubCodeContext context) + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) { - return innerCollectionMarshaller.GenerateIntermediateCleanupStatements(info, context); + if (IsPinningPathSupported(info, context)) + { + return false; + } + return manualMarshallingGenerator.UsesNativeIdentifier(info, context); } - public override bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) + private bool IsPinningPathSupported(TypePositionInfo info, StubCodeContext context) { - return !(blittable && context.PinningSupported) && marshalKind.HasFlag(ByValueContentsMarshalKind.Out); + return context.PinningSupported && enablePinning && !info.IsByRef && !info.IsManagedReturnPosition; } - public override bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + private IEnumerable GeneratePinningPath(TypePositionInfo info, StubCodeContext context) { - return !UseCustomPinningPath(info, context) && innerCollectionMarshaller.UsesNativeIdentifier(info, context); + var (managedIdentifer, nativeIdentifier) = context.GetIdentifiers(info); + string byRefIdentifier = $"__byref_{managedIdentifer}"; + TypeSyntax arrayElementType = elementType; + if (context.CurrentStage == StubCodeContext.Stage.Marshal) + { + // [COMPAT] We use explicit byref calculations here instead of just using a fixed statement + // since a fixed statement converts a zero-length array to a null pointer. + // Many native APIs, such as GDI+, ICU, etc. validate that an array parameter is non-null + // even when the passed in array length is zero. To avoid breaking customers that want to move + // to source-generated interop in subtle ways, we explicitly pass a reference to the 0-th element + // of an array as long as it is non-null, matching the behavior of the built-in interop system + // for single-dimensional zero-based arrays. + + // ref = == null ? ref *(); + var nullRef = + PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, + CastExpression( + PointerType(arrayElementType), + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))); + + var getArrayDataReference = + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + IdentifierName("GetArrayDataReference")), + ArgumentList(SingletonSeparatedList( + Argument(IdentifierName(managedIdentifer))))); + + yield return LocalDeclarationStatement( + VariableDeclaration( + RefType(arrayElementType)) + .WithVariables(SingletonSeparatedList( + VariableDeclarator(Identifier(byRefIdentifier)) + .WithInitializer(EqualsValueClause( + RefExpression(ParenthesizedExpression( + ConditionalExpression( + BinaryExpression( + SyntaxKind.EqualsExpression, + IdentifierName(managedIdentifer), + LiteralExpression( + SyntaxKind.NullLiteralExpression)), + RefExpression(nullRef), + RefExpression(getArrayDataReference))))))))); + } + if (context.CurrentStage == StubCodeContext.Stage.Pin) + { + // fixed ( = &Unsafe.As(ref )) + yield return FixedStatement( + VariableDeclaration(AsNativeType(info), SingletonSeparatedList( + VariableDeclarator(nativeIdentifier) + .WithInitializer(EqualsValueClause( + PrefixUnaryExpression(SyntaxKind.AddressOfExpression, + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_CompilerServices_Unsafe), + GenericName("As").AddTypeArgumentListArguments( + arrayElementType, + PredefinedType(Token(SyntaxKind.ByteKeyword))))) + .AddArgumentListArguments( + Argument(IdentifierName(byRefIdentifier)) + .WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))))), + EmptyStatement()); + } } } -} \ No newline at end of file +} diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs deleted file mode 100644 index 9905c13498df..000000000000 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionBlittableElementsMarshallingGenerator.cs +++ /dev/null @@ -1,128 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - class ContiguousCollectionBlittableElementsMarshallingGenerator : CustomNativeTypeMarshaller - { - private readonly ITypeSymbol elementType; - private readonly ExpressionSyntax numElementsExpression; - - public ContiguousCollectionBlittableElementsMarshallingGenerator( - NativeContiguousCollectionMarshallingInfo marshallingInfo, - ExpressionSyntax numElementsExpression) - :base(marshallingInfo) - { - this.elementType = marshallingInfo.ElementType; - this.numElementsExpression = numElementsExpression; - } - - private ExpressionSyntax GenerateSizeOfElementExpression() - { - return SizeOfExpression(elementType.AsTypeSyntax()); - } - - public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) - { - yield return Argument(GenerateSizeOfElementExpression()); - } - - public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - string marshalerIdentifier = GetMarshallerIdentifier(info, context); - // .ManagedValues.CopyTo(MemoryMarshal.Cast>( GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - string marshalerIdentifier = GetMarshallerIdentifier(info, context); - if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) - { - yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(marshalerIdentifier), - ImplicitObjectCreationExpression() - .AddArgumentListArguments(Argument(GenerateSizeOfElementExpression())))); - } - yield return ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.SetUnmarshalledCollectionLengthMethodName))) - .AddArgumentListArguments(Argument(numElementsExpression))); - } - - public override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - string marshalerIdentifier = GetMarshallerIdentifier(info, context); - // MemoryMarshal.Cast>(.ManagedValues); - yield return ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), - GenericName( - Identifier("Cast")) - .WithTypeArgumentList( - TypeArgumentList( - SeparatedList( - new [] - { - PredefinedType(Token(SyntaxKind.ByteKeyword)), - elementType.AsTypeSyntax() - }))))) - .AddArgumentListArguments( - Argument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName)))), - IdentifierName("CopyTo"))) - .AddArgumentListArguments( - Argument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.ManagedValuesPropertyName))))); - } - } -} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs deleted file mode 100644 index d8e44cdecd8e..000000000000 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ContiguousCollectionNonBlittableElementsMarshallingGenerator.cs +++ /dev/null @@ -1,176 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - class ContiguousCollectionNonBlittableElementsMarshallingGenerator : CustomNativeTypeMarshaller - { - private const string IndexerIdentifier = "__i"; - private readonly IMarshallingGenerator elementMarshaller; - private readonly TypePositionInfo elementInfo; - private readonly ExpressionSyntax numElementsExpression; - - public ContiguousCollectionNonBlittableElementsMarshallingGenerator( - NativeContiguousCollectionMarshallingInfo marshallingInfo, - IMarshallingGenerator elementMarshaller, - TypePositionInfo elementInfo, - ExpressionSyntax numElementsExpression) - :base(marshallingInfo) - { - this.elementMarshaller = elementMarshaller; - this.elementInfo = elementInfo; - this.numElementsExpression = numElementsExpression; - } - - private ExpressionSyntax GenerateSizeOfElementExpression() - { - return SizeOfExpression(elementMarshaller.AsNativeType(elementInfo)); - } - - public override IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) - { - yield return Argument(GenerateSizeOfElementExpression()); - } - - private string GetNativeSpanIdentifier(TypePositionInfo info, StubCodeContext context) - { - return context.GetIdentifiers(info).managed + "__nativeSpan"; - } - - private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(TypePositionInfo info, StubCodeContext context) - { - string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); - return LocalDeclarationStatement(VariableDeclaration( - GenericName( - Identifier(TypeNames.System_Span), - TypeArgumentList( - SingletonSeparatedList(elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax())) - ), - SingletonSeparatedList( - VariableDeclarator(Identifier(nativeSpanIdentifier)) - .WithInitializer(EqualsValueClause( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), - GenericName( - Identifier("Cast")) - .WithTypeArgumentList( - TypeArgumentList( - SeparatedList( - new [] - { - PredefinedType(Token(SyntaxKind.ByteKeyword)), - elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax() - }))))) - .AddArgumentListArguments( - Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(GetMarshallerIdentifier(info, context)), - IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName))))))))); - } - - private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) - { - string marshalerIdentifier = GetMarshallerIdentifier(info, context); - string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); - var elementSubContext = new ContiguousCollectionElementMarshallingCodeContext( - context.CurrentStage, - IndexerIdentifier, - nativeSpanIdentifier, - context); - - string collectionIdentifierForLength = useManagedSpanForLength - ? $"{marshalerIdentifier}.{ManualTypeMarshallingHelper.ManagedValuesPropertyName}" - : nativeSpanIdentifier; - - TypePositionInfo localElementInfo = elementInfo with - { - InstanceIdentifier = info.InstanceIdentifier, - RefKind = info.IsByRef ? info.RefKind : info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind(), - ManagedIndex = info.ManagedIndex, - NativeIndex = info.NativeIndex - }; - - StatementSyntax marshallingStatement = Block( - List(elementMarshaller.Generate( - localElementInfo, - elementSubContext))); - - if (elementMarshaller.AsNativeType(elementInfo) is PointerTypeSyntax) - { - PointerNativeTypeAssignmentRewriter rewriter = new(elementSubContext.GetIdentifiers(localElementInfo).native); - marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement); - } - - // Iterate through the elements of the native collection to unmarshal them - return Block( - GenerateNativeSpanDeclaration(info, context), - MarshallerHelpers.GetForLoop(collectionIdentifierForLength, IndexerIdentifier) - .WithStatement(marshallingStatement)); - } - - public override IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: true); - } - - public override IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - string marshalerIdentifier = GetMarshallerIdentifier(info, context); - if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) - { - yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName(marshalerIdentifier), - ImplicitObjectCreationExpression().AddArgumentListArguments(Argument(GenerateSizeOfElementExpression())))); - } - yield return ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.SetUnmarshalledCollectionLengthMethodName))) - .AddArgumentListArguments(Argument(numElementsExpression))); - } - - public override IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); - } - - public override IEnumerable GenerateIntermediateCleanupStatements(TypePositionInfo info, StubCodeContext context) - { - yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); - } - - /// - /// Rewrite assignment expressions to the native identifier to cast to IntPtr. - /// This handles the case where the native type of a non-blittable managed type is a pointer, - /// which are unsupported in generic type parameters. - /// - private class PointerNativeTypeAssignmentRewriter : CSharpSyntaxRewriter - { - private readonly string nativeIdentifier; - - public PointerNativeTypeAssignmentRewriter(string nativeIdentifier) - { - this.nativeIdentifier = nativeIdentifier; - } - - public override SyntaxNode VisitAssignmentExpression(AssignmentExpressionSyntax node) - { - if (node.Left.ToString() == nativeIdentifier) - { - return node.WithRight( - CastExpression(MarshallerHelpers.SystemIntPtrType, node.Right)); - } - - return node; - } - } - } -} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs deleted file mode 100644 index 0c42f0643d4d..000000000000 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshaller.cs +++ /dev/null @@ -1,324 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Microsoft.Interop -{ - class CustomNativeTypeMarshaller : IMarshallingGenerator - { - private readonly TypeSyntax _nativeTypeSyntax; - private readonly TypeSyntax _nativeLocalTypeSyntax; - private readonly SupportedMarshallingMethods _marshallingMethods; - private readonly bool _hasFreeNative; - private readonly bool _useValueProperty; - private readonly bool _marshalerTypePinnable; - - public CustomNativeTypeMarshaller(NativeMarshallingAttributeInfo marshallingInfo) - { - ITypeSymbol nativeType = marshallingInfo.ValuePropertyType ?? marshallingInfo.NativeMarshallingType; - _nativeTypeSyntax = ParseTypeName(nativeType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); - _nativeLocalTypeSyntax = ParseTypeName(marshallingInfo.NativeMarshallingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); - _marshallingMethods = marshallingInfo.MarshallingMethods; - _hasFreeNative = ManualTypeMarshallingHelper.HasFreeNativeMethod(marshallingInfo.NativeMarshallingType); - _useValueProperty = marshallingInfo.ValuePropertyType != null; - _marshalerTypePinnable = marshallingInfo.NativeTypePinnable; - } - - public CustomNativeTypeMarshaller(GeneratedNativeMarshallingAttributeInfo marshallingInfo) - { - _nativeTypeSyntax = _nativeLocalTypeSyntax = ParseTypeName(marshallingInfo.NativeMarshallingFullyQualifiedTypeName); - _marshallingMethods = SupportedMarshallingMethods.ManagedToNative | SupportedMarshallingMethods.NativeToManaged; - _hasFreeNative = true; - _useValueProperty = false; - _marshalerTypePinnable = false; - } - - public string GetMarshallerIdentifier(TypePositionInfo info, StubCodeContext context) - { - return _useValueProperty - ? MarshallerHelpers.GetMarshallerIdentifier(info, context) - : context.GetIdentifiers(info).native; - } - - public TypeSyntax AsNativeType(TypePositionInfo info) - { - return _nativeTypeSyntax; - } - - public ParameterSyntax AsParameter(TypePositionInfo info) - { - var type = info.IsByRef - ? PointerType(AsNativeType(info)) - : AsNativeType(info); - return Parameter(Identifier(info.InstanceIdentifier)) - .WithType(type); - } - - public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) - { - string identifier = context.GetIdentifiers(info).native; - if (info.IsByRef) - { - return Argument( - PrefixUnaryExpression( - SyntaxKind.AddressOfExpression, - IdentifierName(identifier))); - } - - if (context.PinningSupported && (_marshallingMethods & SupportedMarshallingMethods.Pinning) != 0) - { - return Argument(CastExpression(AsNativeType(info), IdentifierName(identifier))); - } - - return Argument(IdentifierName(identifier)); - } - - public virtual IEnumerable Generate(TypePositionInfo info, StubCodeContext context) - { - (string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info); - string marshalerIdentifier = GetMarshallerIdentifier(info, context); - if (!info.IsManagedReturnPosition - && !info.IsByRef - && context.PinningSupported - && (_marshallingMethods & SupportedMarshallingMethods.Pinning) != 0) - { - if (context.CurrentStage == StubCodeContext.Stage.Pin) - { - yield return FixedStatement( - VariableDeclaration( - PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))), - SingletonSeparatedList( - VariableDeclarator(Identifier(nativeIdentifier)) - .WithInitializer(EqualsValueClause( - IdentifierName(managedIdentifier) - )) - ) - ), - EmptyStatement() - ); - } - yield break; - } - - switch (context.CurrentStage) - { - case StubCodeContext.Stage.Setup: - if (_useValueProperty) - { - yield return LocalDeclarationStatement( - VariableDeclaration( - _nativeLocalTypeSyntax, - SingletonSeparatedList( - VariableDeclarator(marshalerIdentifier) - .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.DefaultLiteralExpression)))))); - } - break; - case StubCodeContext.Stage.Marshal: - if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) - { - // Stack space must be usable and the marshaler must support stackalloc to use stackalloc. - // We also require pinning to be supported to enable users to pass the stackalloc'd Span - // to native code by having the marshaler type return a byref to the Span's elements - // in its GetPinnableReference method. - bool scenarioSupportsStackalloc = context.StackSpaceUsable - && (_marshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0 - && context.PinningSupported; - - List arguments = new List - { - Argument(IdentifierName(managedIdentifier)) - }; - - if (scenarioSupportsStackalloc && (!info.IsByRef || info.RefKind == RefKind.In)) - { - string stackallocIdentifier = $"{managedIdentifier}__stackptr"; - // byte* __stackptr = stackalloc byte[<_nativeLocalType>.StackBufferSize]; - yield return LocalDeclarationStatement( - VariableDeclaration( - PointerType(PredefinedType(Token(SyntaxKind.ByteKeyword))), - SingletonSeparatedList( - VariableDeclarator(stackallocIdentifier) - .WithInitializer(EqualsValueClause( - StackAllocArrayCreationExpression( - ArrayType( - PredefinedType(Token(SyntaxKind.ByteKeyword)), - SingletonList(ArrayRankSpecifier(SingletonSeparatedList( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - _nativeLocalTypeSyntax, - IdentifierName(ManualTypeMarshallingHelper.StackBufferSizeFieldName)) - )))))))))); - - // new Span(__stackptr, <_nativeLocalType>.StackBufferSize) - arguments.Add(Argument( - ObjectCreationExpression( - GenericName(Identifier(TypeNames.System_Span), - TypeArgumentList(SingletonSeparatedList( - PredefinedType(Token(SyntaxKind.ByteKeyword)))))) - .WithArgumentList( - ArgumentList(SeparatedList(new ArgumentSyntax[] - { - Argument(IdentifierName(stackallocIdentifier)), - Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - _nativeLocalTypeSyntax, - IdentifierName(ManualTypeMarshallingHelper.StackBufferSizeFieldName))) - }))))); - } - - arguments.AddRange(GenerateAdditionalNativeTypeConstructorArguments(info, context)); - - // = new <_nativeLocalType>(); - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(marshalerIdentifier), - ObjectCreationExpression(_nativeLocalTypeSyntax) - .WithArgumentList(ArgumentList(SeparatedList(arguments))))); - - foreach (var statement in GenerateIntermediateMarshallingStatements(info, context)) - { - yield return statement; - } - - bool skipValueProperty = _marshalerTypePinnable && (!info.IsByRef || info.RefKind == RefKind.In); - - if (_useValueProperty && !skipValueProperty) - { - // = .Value; - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(nativeIdentifier), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)))); - } - } - break; - case StubCodeContext.Stage.Pin: - if (_marshalerTypePinnable && (!info.IsByRef || info.RefKind == RefKind.In)) - { - // fixed (<_nativeTypeSyntax> = &) - yield return FixedStatement( - VariableDeclaration( - _nativeTypeSyntax, - SingletonSeparatedList( - VariableDeclarator(nativeIdentifier) - .WithInitializer(EqualsValueClause( - PrefixUnaryExpression(SyntaxKind.AddressOfExpression, - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.GetPinnableReferenceName)), - ArgumentList())))))), - EmptyStatement()); - } - break; - case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) - { - foreach (var statement in GeneratePreUnmarshallingStatements(info, context)) - { - yield return statement; - } - - if (_useValueProperty) - { - // .Value = ; - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), - IdentifierName(nativeIdentifier))); - } - - foreach (var statement in GenerateIntermediateUnmarshallingStatements(info, context)) - { - yield return statement; - } - - // = .ToManaged(); - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(managedIdentifier), - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.ToManagedMethodName))))); - } - break; - case StubCodeContext.Stage.Cleanup: - foreach (var statement in GenerateIntermediateCleanupStatements(info, context)) - { - yield return statement; - } - if (_hasFreeNative) - { - // .FreeNative(); - yield return ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshalerIdentifier), - IdentifierName(ManualTypeMarshallingHelper.FreeNativeMethodName)))); - } - break; - // TODO: Determine how to keep alive delegates that are in struct fields. - default: - break; - } - } - - public virtual IEnumerable GenerateAdditionalNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) - { - return Array.Empty(); - } - - public virtual IEnumerable GenerateIntermediateMarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - return Array.Empty(); - } - - public virtual IEnumerable GeneratePreUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - return Array.Empty(); - } - - public virtual IEnumerable GenerateIntermediateUnmarshallingStatements(TypePositionInfo info, StubCodeContext context) - { - return Array.Empty(); - } - - public virtual IEnumerable GenerateIntermediateCleanupStatements(TypePositionInfo info, StubCodeContext context) - { - return Array.Empty(); - } - - public virtual bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) - { - if (info.IsManagedReturnPosition || info.IsByRef && info.RefKind != RefKind.In) - { - return true; - } - if (context.PinningSupported) - { - if (!info.IsByRef && (_marshallingMethods & SupportedMarshallingMethods.Pinning) != 0) - { - return false; - } - else if (_marshalerTypePinnable) - { - return false; - } - } - return true; - } - - public virtual bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) => false; - } -} diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs new file mode 100644 index 000000000000..849af1258092 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + /// + /// Implements generating code for an instance. + /// + internal sealed class CustomNativeTypeMarshallingGenerator : IMarshallingGenerator + { + private readonly ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller; + + public CustomNativeTypeMarshallingGenerator(ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller) + { + this.nativeTypeMarshaller = nativeTypeMarshaller; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return nativeTypeMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return nativeTypeMarshaller.AsNativeType(info); + } + + public ParameterSyntax AsParameter(TypePositionInfo info) + { + var type = info.IsByRef + ? PointerType(AsNativeType(info)) + : AsNativeType(info); + return Parameter(Identifier(info.InstanceIdentifier)) + .WithType(type); + } + + public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) + { + switch (context.CurrentStage) + { + case StubCodeContext.Stage.Setup: + return nativeTypeMarshaller.GenerateSetupStatements(info, context); + case StubCodeContext.Stage.Marshal: + if (!info.IsManagedReturnPosition && info.RefKind != RefKind.Out) + { + return nativeTypeMarshaller.GenerateMarshalStatements(info, context, nativeTypeMarshaller.GetNativeTypeConstructorArguments(info, context)); + } + break; + case StubCodeContext.Stage.Pin: + if (!info.IsByRef || info.RefKind == RefKind.In) + { + return nativeTypeMarshaller.GeneratePinStatements(info, context); + } + break; + case StubCodeContext.Stage.Unmarshal: + if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + { + return nativeTypeMarshaller.GenerateUnmarshalStatements(info, context); + } + break; + case StubCodeContext.Stage.Cleanup: + return nativeTypeMarshaller.GenerateCleanupStatements(info, context); + default: + break; + } + + return Array.Empty(); + } + + public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) + { + return false; + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return nativeTypeMarshaller.UsesNativeIdentifier(info, context); + } + } +} diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs new file mode 100644 index 000000000000..4ac1cdc52e7f --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs @@ -0,0 +1,969 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + /// + /// The base interface for implementing various different aspects of the custom native type and collection marshalling specs. + /// + + interface ICustomNativeTypeMarshallingStrategy + { + TypeSyntax AsNativeType(TypePositionInfo info); + + ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context); + + IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context); + + IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments); + + IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context); + + IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context); + + IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context); + + IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context); + + bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context); + } + + /// + /// Marshalling support for a type that has a custom native type. + /// + internal sealed class SimpleCustomNativeTypeMarshalling : ICustomNativeTypeMarshallingStrategy + { + private readonly TypeSyntax nativeTypeSyntax; + + public SimpleCustomNativeTypeMarshalling(TypeSyntax nativeTypeSyntax) + { + this.nativeTypeSyntax = nativeTypeSyntax; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + string identifier = context.GetIdentifiers(info).native; + if (info.IsByRef) + { + return Argument( + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + IdentifierName(identifier))); + } + + return Argument(IdentifierName(identifier)); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return nativeTypeSyntax; + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return true; + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + // = new(); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetIdentifiers(info).native), + ImplicitObjectCreationExpression() + .WithArgumentList(ArgumentList(SeparatedList(nativeTypeConstructorArguments))))); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + // If the current element is being marshalled by-value [Out], then don't call the ToManaged method and do the assignment. + // The assignment will end up being a no-op and will not be observed. + if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + { + yield break; + } + + var (managedIdentifier, nativeIdentifier) = context.GetIdentifiers(info); + // = .ToManaged(); + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(managedIdentifier), + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nativeIdentifier), + IdentifierName(ManualTypeMarshallingHelper.ToManagedMethodName))))); + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + yield return Argument(IdentifierName(context.GetIdentifiers(info).managed)); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return Array.Empty(); + } + } + + /// + /// A context that redefines the 'native' identifier for a TypePositionInfo to be the marshaller identifier. + /// + internal class CustomNativeTypeWithValuePropertyStubContext : StubCodeContext + { + private readonly StubCodeContext parentContext; + + public CustomNativeTypeWithValuePropertyStubContext(StubCodeContext parentContext) + { + this.parentContext = parentContext; + } + + public override bool PinningSupported => parentContext.PinningSupported; + + public override bool StackSpaceUsable => parentContext.StackSpaceUsable; + + public override bool CanUseAdditionalTemporaryState => parentContext.CanUseAdditionalTemporaryState; + + public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index) + { + return parentContext.GetTypePositionInfoForManagedIndex(index); + } + + public override (string managed, string native) GetIdentifiers(TypePositionInfo info) + { + return (parentContext.GetIdentifiers(info).managed, MarshallerHelpers.GetMarshallerIdentifier(info, parentContext)); + } + } + + /// + /// Marshaller that enables support of a Value property on a native type. + /// + internal sealed class CustomNativeTypeWithValuePropertyMarshalling : ICustomNativeTypeMarshallingStrategy + { + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + private readonly TypeSyntax valuePropertyType; + + public CustomNativeTypeWithValuePropertyMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller, TypeSyntax valuePropertyType) + { + this.innerMarshaller = innerMarshaller; + this.valuePropertyType = valuePropertyType; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + string identifier = context.GetIdentifiers(info).native; + if (info.IsByRef) + { + return Argument( + PrefixUnaryExpression( + SyntaxKind.AddressOfExpression, + IdentifierName(identifier))); + } + + return Argument(IdentifierName(identifier)); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return valuePropertyType; + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return true; + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + return innerMarshaller.GenerateCleanupStatements(info, subContext); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + foreach (var statement in innerMarshaller.GenerateMarshalStatements(info, subContext, nativeTypeConstructorArguments)) + { + yield return statement; + } + + // = .Value; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetIdentifiers(info).native), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(subContext.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)))); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + + // .Value = ; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(subContext.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), + IdentifierName(context.GetIdentifiers(info).native))); + + foreach (var statement in innerMarshaller.GenerateUnmarshalStatements(info, subContext)) + { + yield return statement; + } + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + return innerMarshaller.GetNativeTypeConstructorArguments(info, subContext); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + yield return LocalDeclarationStatement( + VariableDeclaration( + innerMarshaller.AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(subContext.GetIdentifiers(info).native) + .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.DefaultLiteralExpression)))))); + + foreach (var statement in innerMarshaller.GenerateSetupStatements(info, subContext)) + { + yield return statement; + } + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + return innerMarshaller.GeneratePinStatements(info, subContext); + } + } + + /// + /// Marshaller that enables support for a stackalloc constructor variant on a native type. + /// + internal sealed class StackallocOptimizationMarshalling : ICustomNativeTypeMarshallingStrategy + { + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + + public StackallocOptimizationMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller) + { + this.innerMarshaller = innerMarshaller; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return innerMarshaller.AsNativeType(info); + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateCleanupStatements(info, context); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + if (StackAllocOptimizationValid(info, context)) + { + // byte* __stackptr = stackalloc byte[<_nativeLocalType>.StackBufferSize]; + yield return LocalDeclarationStatement( + VariableDeclaration( + PointerType(PredefinedType(Token(SyntaxKind.ByteKeyword))), + SingletonSeparatedList( + VariableDeclarator(GetStackAllocPointerIdentifier(info, context)) + .WithInitializer(EqualsValueClause( + StackAllocArrayCreationExpression( + ArrayType( + PredefinedType(Token(SyntaxKind.ByteKeyword)), + SingletonList(ArrayRankSpecifier(SingletonSeparatedList( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + AsNativeType(info), + IdentifierName(ManualTypeMarshallingHelper.StackBufferSizeFieldName)) + )))))))))); + } + + foreach (var statement in innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments)) + { + yield return statement; + } + } + + private static bool StackAllocOptimizationValid(TypePositionInfo info, StubCodeContext context) + { + return context.StackSpaceUsable && context.PinningSupported && (!info.IsByRef || info.RefKind == RefKind.In); + } + + private static string GetStackAllocPointerIdentifier(TypePositionInfo info, StubCodeContext context) + { + return $"{context.GetIdentifiers(info).managed}__stackptr"; + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GeneratePinStatements(info, context); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateSetupStatements(info, context); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateUnmarshalStatements(info, context); + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + foreach (var arg in innerMarshaller.GetNativeTypeConstructorArguments(info, context)) + { + yield return arg; + } + if (StackAllocOptimizationValid(info, context)) + { + yield return Argument( + ObjectCreationExpression( + GenericName(Identifier(TypeNames.System_Span), + TypeArgumentList(SingletonSeparatedList( + PredefinedType(Token(SyntaxKind.ByteKeyword)))))) + .WithArgumentList( + ArgumentList(SeparatedList(new ArgumentSyntax[] + { + Argument(IdentifierName(GetStackAllocPointerIdentifier(info, context))), + Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + AsNativeType(info), + IdentifierName(ManualTypeMarshallingHelper.StackBufferSizeFieldName))) + })))); + } + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.UsesNativeIdentifier(info, context); + } + } + + /// + /// Marshaller that enables support for a FreeNative method on a native type. + /// + internal sealed class FreeNativeCleanupStrategy : ICustomNativeTypeMarshallingStrategy + { + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + + public FreeNativeCleanupStrategy(ICustomNativeTypeMarshallingStrategy innerMarshaller) + { + this.innerMarshaller = innerMarshaller; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return innerMarshaller.AsNativeType(info); + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (var statement in innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } + + // .FreeNative(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(context.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.FreeNativeMethodName)))); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + return innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GeneratePinStatements(info, context); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateSetupStatements(info, context); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateUnmarshalStatements(info, context); + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GetNativeTypeConstructorArguments(info, context); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.UsesNativeIdentifier(info, context); + } + } + + /// + /// Marshaller that enables support for a GetPinnableReference method on a native type, with a Value property fallback. + /// + internal sealed class PinnableMarshallerTypeMarshalling : ICustomNativeTypeMarshallingStrategy + { + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + private readonly TypeSyntax valuePropertyType; + + public PinnableMarshallerTypeMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller, TypeSyntax valuePropertyType) + { + this.innerMarshaller = innerMarshaller; + this.valuePropertyType = valuePropertyType; + } + + private bool CanPinMarshaller(TypePositionInfo info, StubCodeContext context) + { + return !info.IsByRef || info.RefKind == RefKind.In; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return valuePropertyType; + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + return innerMarshaller.GenerateCleanupStatements(info, subContext); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + foreach (var statement in innerMarshaller.GenerateMarshalStatements(info, subContext, nativeTypeConstructorArguments)) + { + yield return statement; + } + + if (!CanPinMarshaller(info, context)) + { + // = .Value; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetIdentifiers(info).native), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(subContext.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)))); + } + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + // fixed (<_nativeTypeSyntax> = &) + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + yield return FixedStatement( + VariableDeclaration( + innerMarshaller.AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(context.GetIdentifiers(info).native) + .WithInitializer(EqualsValueClause( + PrefixUnaryExpression(SyntaxKind.AddressOfExpression, + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(subContext.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.GetPinnableReferenceName)), + ArgumentList())))))), + EmptyStatement()); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + yield return LocalDeclarationStatement( + VariableDeclaration( + innerMarshaller.AsNativeType(info), + SingletonSeparatedList( + VariableDeclarator(subContext.GetIdentifiers(info).native) + .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.DefaultLiteralExpression)))))); + + foreach (var statement in innerMarshaller.GenerateSetupStatements(info, subContext)) + { + yield return statement; + } + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); + + // .Value = ; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(subContext.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), + IdentifierName(context.GetIdentifiers(info).native))); + + foreach (var statement in innerMarshaller.GenerateUnmarshalStatements(info, subContext)) + { + yield return statement; + } + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GetNativeTypeConstructorArguments(info, context); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + if (context.PinningSupported) + { + return false; + } + return innerMarshaller.UsesNativeIdentifier(info, context); + } + } + + /// + /// Marshaller that enables support for native types with the constructor variants that take a sizeOfElement int parameter and that have a SetUnmarshalledCollectionLength method. + /// + internal sealed class NumElementsExpressionMarshalling : ICustomNativeTypeMarshallingStrategy + { + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + private readonly ExpressionSyntax numElementsExpression; + private readonly ExpressionSyntax sizeOfElementExpression; + + public NumElementsExpressionMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller, ExpressionSyntax numElementsExpression, ExpressionSyntax sizeOfElementExpression) + { + this.innerMarshaller = innerMarshaller; + this.numElementsExpression = numElementsExpression; + this.sizeOfElementExpression = sizeOfElementExpression; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return innerMarshaller.AsNativeType(info); + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateCleanupStatements(info, context); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + return innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GeneratePinStatements(info, context); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateSetupStatements(info, context); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + string marshalerIdentifier = MarshallerHelpers.GetMarshallerIdentifier(info, context); + if (info.RefKind == RefKind.Out || info.IsManagedReturnPosition) + { + yield return ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(marshalerIdentifier), + ImplicitObjectCreationExpression().AddArgumentListArguments(Argument(sizeOfElementExpression)))); + } + + if (info.IsByRef || !info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + { + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshalerIdentifier), + IdentifierName(ManualTypeMarshallingHelper.SetUnmarshalledCollectionLengthMethodName))) + .AddArgumentListArguments(Argument(numElementsExpression))); + } + + foreach (var statement in innerMarshaller.GenerateUnmarshalStatements(info, context)) + { + yield return statement; + } + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + foreach (var arg in innerMarshaller.GetNativeTypeConstructorArguments(info, context)) + { + yield return arg; + } + yield return Argument(sizeOfElementExpression); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.UsesNativeIdentifier(info, context); + } + } + + /// + /// Marshaller that enables support for marshalling blittable elements of a contiguous collection via a native type that implements the contiguous collection marshalling spec. + /// + internal sealed class ContiguousBlittableElementCollectionMarshalling : ICustomNativeTypeMarshallingStrategy + { + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + private readonly TypeSyntax elementType; + + public ContiguousBlittableElementCollectionMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller, TypeSyntax elementType) + { + this.innerMarshaller = innerMarshaller; + this.elementType = elementType; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return innerMarshaller.AsNativeType(info); + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateCleanupStatements(info, context); + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + string nativeIdentifier = context.GetIdentifiers(info).native; + foreach (var statement in innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments)) + { + yield return statement; + } + + // .ManagedValues.CopyTo(MemoryMarshal.Cast>(.NativeValueStorage)); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nativeIdentifier), + IdentifierName(ManualTypeMarshallingHelper.ManagedValuesPropertyName)), + IdentifierName("CopyTo"))) + .AddArgumentListArguments( + Argument( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + GenericName( + Identifier("Cast")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new[] + { + PredefinedType(Token(SyntaxKind.ByteKeyword)), + elementType + }))))) + .AddArgumentListArguments( + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nativeIdentifier), + IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName))))))); + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GeneratePinStatements(info, context); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateSetupStatements(info, context); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + string nativeIdentifier = context.GetIdentifiers(info).native; + // MemoryMarshal.Cast>(.NativeValueStorage).CopyTo(.ManagedValues); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + GenericName( + Identifier("Cast")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new[] + { + PredefinedType(Token(SyntaxKind.ByteKeyword)), + elementType + }))))) + .AddArgumentListArguments( + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nativeIdentifier), + IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName)))), + IdentifierName("CopyTo"))) + .AddArgumentListArguments( + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nativeIdentifier), + IdentifierName(ManualTypeMarshallingHelper.ManagedValuesPropertyName))))); + + foreach (var statement in innerMarshaller.GenerateUnmarshalStatements(info, context)) + { + yield return statement; + } + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GetNativeTypeConstructorArguments(info, context); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.UsesNativeIdentifier(info, context); + } + } + + /// + /// Marshaller that enables support for marshalling non-blittable elements of a contiguous collection via a native type that implements the contiguous collection marshalling spec. + /// + internal sealed class ContiguousNonBlittableElementCollectionMarshalling : ICustomNativeTypeMarshallingStrategy + { + private const string IndexerIdentifier = "__i"; + + private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller; + private readonly IMarshallingGenerator elementMarshaller; + private readonly TypePositionInfo elementInfo; + + public ContiguousNonBlittableElementCollectionMarshalling(ICustomNativeTypeMarshallingStrategy innerMarshaller, + IMarshallingGenerator elementMarshaller, + TypePositionInfo elementInfo) + { + this.innerMarshaller = innerMarshaller; + this.elementMarshaller = elementMarshaller; + this.elementInfo = elementInfo; + } + + private string GetNativeSpanIdentifier(TypePositionInfo info, StubCodeContext context) + { + return context.GetIdentifiers(info).managed + "__nativeSpan"; + } + + private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(TypePositionInfo info, StubCodeContext context) + { + string nativeIdentifier = context.GetIdentifiers(info).native; + string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); + return LocalDeclarationStatement(VariableDeclaration( + GenericName( + Identifier(TypeNames.System_Span), + TypeArgumentList( + SingletonSeparatedList(elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax())) + ), + SingletonSeparatedList( + VariableDeclarator(Identifier(nativeSpanIdentifier)) + .WithInitializer(EqualsValueClause( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName(TypeNames.System_Runtime_InteropServices_MemoryMarshal), + GenericName( + Identifier("Cast")) + .WithTypeArgumentList( + TypeArgumentList( + SeparatedList( + new[] + { + PredefinedType(Token(SyntaxKind.ByteKeyword)), + elementMarshaller.AsNativeType(elementInfo).GetCompatibleGenericTypeParameterSyntax() + }))))) + .AddArgumentListArguments( + Argument(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(nativeIdentifier), + IdentifierName(ManualTypeMarshallingHelper.NativeValueStoragePropertyName))))))))); + } + + private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength) + { + string nativeIdentifier = context.GetIdentifiers(info).native; + string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context); + var elementSubContext = new ContiguousCollectionElementMarshallingCodeContext( + context.CurrentStage, + IndexerIdentifier, + nativeSpanIdentifier, + context); + + string collectionIdentifierForLength = useManagedSpanForLength + ? $"{nativeIdentifier}.{ManualTypeMarshallingHelper.ManagedValuesPropertyName}" + : nativeSpanIdentifier; + + TypePositionInfo localElementInfo = elementInfo with + { + InstanceIdentifier = info.InstanceIdentifier, + RefKind = info.IsByRef ? info.RefKind : info.ByValueContentsMarshalKind.GetRefKindForByValueContentsKind(), + ManagedIndex = info.ManagedIndex, + NativeIndex = info.NativeIndex + }; + + StatementSyntax marshallingStatement = Block( + List(elementMarshaller.Generate( + localElementInfo, + elementSubContext))); + + if (elementMarshaller.AsNativeType(elementInfo) is PointerTypeSyntax) + { + PointerNativeTypeAssignmentRewriter rewriter = new(elementSubContext.GetIdentifiers(localElementInfo).native); + marshallingStatement = (StatementSyntax)rewriter.Visit(marshallingStatement); + } + + // Iterate through the elements of the native collection to unmarshal them + return Block( + GenerateNativeSpanDeclaration(info, context), + MarshallerHelpers.GetForLoop(collectionIdentifierForLength, IndexerIdentifier) + .WithStatement(marshallingStatement)); + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return innerMarshaller.AsNativeType(info); + } + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); + foreach (var statement in innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } + } + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) + { + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: true); + foreach (var statement in innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments)) + { + yield return statement; + } + } + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GeneratePinStatements(info, context); + } + + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GenerateSetupStatements(info, context); + } + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) + { + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: false); + foreach (var statement in innerMarshaller.GenerateUnmarshalStatements(info, context)) + { + yield return statement; + } + } + + public IEnumerable GetNativeTypeConstructorArguments(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.GetNativeTypeConstructorArguments(info, context); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + return innerMarshaller.UsesNativeIdentifier(info, context); + } + + /// + /// Rewrite assignment expressions to the native identifier to cast to IntPtr. + /// This handles the case where the native type of a non-blittable managed type is a pointer, + /// which are unsupported in generic type parameters. + /// + private class PointerNativeTypeAssignmentRewriter : CSharpSyntaxRewriter + { + private readonly string nativeIdentifier; + + public PointerNativeTypeAssignmentRewriter(string nativeIdentifier) + { + this.nativeIdentifier = nativeIdentifier; + } + + public override SyntaxNode VisitAssignmentExpression(AssignmentExpressionSyntax node) + { + if (node.Left.ToString() == nativeIdentifier) + { + return node.WithRight( + CastExpression(MarshallerHelpers.SystemIntPtrType, node.Right)); + } + + return node; + } + } + } +} diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index e31615c405de..fe30c208b7f8 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -411,6 +411,43 @@ ExpressionSyntax GetExpressionForParam(TypePositionInfo? paramInfo) } private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo, AnalyzerConfigOptions options) + { + ValidateCustomNativeTypeMarshallingSupported(info, context, marshalInfo); + + ICustomNativeTypeMarshallingStrategy marshallingStrategy = new SimpleCustomNativeTypeMarshalling(marshalInfo.NativeMarshallingType.AsTypeSyntax()); + + if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0) + { + marshallingStrategy = new StackallocOptimizationMarshalling(marshallingStrategy); + } + + if (ManualTypeMarshallingHelper.HasFreeNativeMethod(marshalInfo.NativeMarshallingType)) + { + marshallingStrategy = new FreeNativeCleanupStrategy(marshallingStrategy); + } + + // Collections have extra configuration, so handle them here. + if (marshalInfo is NativeContiguousCollectionMarshallingInfo collectionMarshallingInfo) + { + return CreateNativeCollectionMarshaller(info, context, collectionMarshallingInfo, options, marshallingStrategy); + } + + if (marshalInfo.ValuePropertyType is not null) + { + marshallingStrategy = DecorateWithValuePropertyStrategy(marshalInfo, marshallingStrategy); + } + + IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy); + + if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0) + { + return new PinnableManagedValueMarshaller(marshallingGenerator); + } + + return marshallingGenerator; + } + + private static void ValidateCustomNativeTypeMarshallingSupported(TypePositionInfo info, StubCodeContext context, NativeMarshallingAttributeInfo marshalInfo) { if (marshalInfo.ValuePropertyType is not null && !context.CanUseAdditionalTemporaryState) { @@ -422,7 +459,7 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi // The marshalling method for this type doesn't support marshalling from native to managed, // but our scenario requires marshalling from native to managed. - if ((info.RefKind == RefKind.Ref || info.RefKind == RefKind.Out || info.IsManagedReturnPosition) + if ((info.RefKind == RefKind.Ref || info.RefKind == RefKind.Out || info.IsManagedReturnPosition) && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.NativeToManaged) == 0) { throw new MarshallingNotSupportedException(info, context) @@ -435,9 +472,9 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi // Pinning is required for the stackalloc marshalling to enable users to safely pass the stackalloc Span's byref // to native if we ever start using a conditional stackalloc method and cannot guarantee that the Span we provide // the user with is backed by stack allocated memory. - else if (!info.IsByRef - && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0 - && !(context.PinningSupported && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) == 0) + else if (!info.IsByRef + && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0 + && !(context.PinningSupported && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) == 0) && !(context.StackSpaceUsable && context.PinningSupported && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) == 0)) { throw new MarshallingNotSupportedException(info, context) @@ -448,8 +485,8 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi // The marshalling method for this type doesn't support marshalling from managed to native by reference, // but our scenario requires marshalling from managed to native by reference. // "in" byref supports stack marshalling. - else if (info.RefKind == RefKind.In - && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0 + else if (info.RefKind == RefKind.In + && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0 && !(context.StackSpaceUsable && context.PinningSupported && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNativeStackalloc) != 0)) { throw new MarshallingNotSupportedException(info, context) @@ -460,7 +497,7 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi // The marshalling method for this type doesn't support marshalling from managed to native by reference, // but our scenario requires marshalling from managed to native by reference. // "ref" byref marshalling doesn't support stack marshalling - else if (info.RefKind == RefKind.Ref + else if (info.RefKind == RefKind.Ref && (marshalInfo.MarshallingMethods & SupportedMarshallingMethods.ManagedToNative) == 0) { throw new MarshallingNotSupportedException(info, context) @@ -468,43 +505,78 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi NotSupportedDetails = string.Format(Resources.CustomTypeMarshallingManagedToNativeUnsupported, marshalInfo.NativeMarshallingType.ToDisplayString()) }; } + } - if (marshalInfo is NativeContiguousCollectionMarshallingInfo collectionMarshallingInfo) + private static ICustomNativeTypeMarshallingStrategy DecorateWithValuePropertyStrategy(NativeMarshallingAttributeInfo marshalInfo, ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller) + { + TypeSyntax valuePropertyTypeSyntax = marshalInfo.ValuePropertyType!.AsTypeSyntax(); + if (ManualTypeMarshallingHelper.FindGetPinnableReference(marshalInfo.ValuePropertyType!) is not null) { - return CreateNativeCollectionMarshaller(info, context, collectionMarshallingInfo, options); + return new PinnableMarshallerTypeMarshalling(nativeTypeMarshaller, valuePropertyTypeSyntax); } - - return new CustomNativeTypeMarshaller(marshalInfo); + + return new CustomNativeTypeWithValuePropertyMarshalling(nativeTypeMarshaller, valuePropertyTypeSyntax); } - private static IMarshallingGenerator CreateNativeCollectionMarshaller(TypePositionInfo info, StubCodeContext context, NativeContiguousCollectionMarshallingInfo collectionMarshallingInfo, AnalyzerConfigOptions options) + private static IMarshallingGenerator CreateNativeCollectionMarshaller( + TypePositionInfo info, + StubCodeContext context, + NativeContiguousCollectionMarshallingInfo collectionInfo, + AnalyzerConfigOptions options, + ICustomNativeTypeMarshallingStrategy marshallingStrategy) { - var elementInfo = TypePositionInfo.CreateForType(collectionMarshallingInfo.ElementType, collectionMarshallingInfo.ElementMarshallingInfo); + var elementInfo = TypePositionInfo.CreateForType(collectionInfo.ElementType, collectionInfo.ElementMarshallingInfo); var elementMarshaller = Create( elementInfo, new ContiguousCollectionElementMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, string.Empty, context), options); + var elementType = elementMarshaller.AsNativeType(elementInfo); + + bool isBlittable = elementMarshaller == Blittable; + + if (isBlittable) + { + marshallingStrategy = new ContiguousBlittableElementCollectionMarshalling(marshallingStrategy, collectionInfo.ElementType.AsTypeSyntax()); + } + else + { + marshallingStrategy = new ContiguousNonBlittableElementCollectionMarshalling(marshallingStrategy, elementMarshaller, elementInfo); + } + + // Explicitly insert the Value property handling here (before numElements handling) so that the numElements handling will be emitted before the Value property handling in unmarshalling. + if (collectionInfo.ValuePropertyType is not null) + { + marshallingStrategy = DecorateWithValuePropertyStrategy(collectionInfo, marshallingStrategy); + } + ExpressionSyntax numElementsExpression = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)); if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) { // In this case, we need a numElementsExpression supplied from metadata, so we'll calculate it here. - numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, collectionMarshallingInfo.ElementCountInfo, context, options); + numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, collectionInfo.ElementCountInfo, context, options); } - if (collectionMarshallingInfo.UseDefaultMarshalling && info.ManagedType is IArrayTypeSymbol { IsSZArray: true }) + marshallingStrategy = new NumElementsExpressionMarshalling( + marshallingStrategy, + numElementsExpression, + SizeOfExpression(elementType)); + + if (collectionInfo.UseDefaultMarshalling && info.ManagedType is IArrayTypeSymbol { IsSZArray: true }) { - if (elementMarshaller == Blittable) - { - return new ArrayMarshaller(new ContiguousCollectionBlittableElementsMarshallingGenerator(collectionMarshallingInfo, numElementsExpression), collectionMarshallingInfo); - } - return new ArrayMarshaller(new ContiguousCollectionNonBlittableElementsMarshallingGenerator(collectionMarshallingInfo, elementMarshaller, elementInfo, numElementsExpression), collectionMarshallingInfo); + return new ArrayMarshaller( + new CustomNativeTypeMarshallingGenerator(marshallingStrategy), + elementType, + isBlittable); } - if (elementMarshaller == Blittable) + IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy); + + if ((collectionInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0) { - return new ContiguousCollectionBlittableElementsMarshallingGenerator(collectionMarshallingInfo, numElementsExpression); + return new PinnableManagedValueMarshaller(marshallingGenerator); } - return new ContiguousCollectionNonBlittableElementsMarshallingGenerator(collectionMarshallingInfo, elementMarshaller, elementInfo, numElementsExpression); + + return marshallingGenerator; } } } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs new file mode 100644 index 000000000000..1437ace86499 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs @@ -0,0 +1,84 @@ +using System.Collections.Generic; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Interop +{ + internal class PinnableManagedValueMarshaller : IMarshallingGenerator + { + private readonly IMarshallingGenerator manualMarshallingGenerator; + + public PinnableManagedValueMarshaller(IMarshallingGenerator manualMarshallingGenerator) + { + this.manualMarshallingGenerator = manualMarshallingGenerator; + } + + public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) + { + if (IsPinningPathSupported(info, context)) + { + string identifier = context.GetIdentifiers(info).native; + return Argument(CastExpression(AsNativeType(info), IdentifierName(identifier))); + } + return manualMarshallingGenerator.AsArgument(info, context); + } + + public TypeSyntax AsNativeType(TypePositionInfo info) + { + return manualMarshallingGenerator.AsNativeType(info); + } + + public ParameterSyntax AsParameter(TypePositionInfo info) + { + return manualMarshallingGenerator.AsParameter(info); + } + + public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) + { + if (IsPinningPathSupported(info, context)) + { + return GeneratePinningPath(info, context); + } + return manualMarshallingGenerator.Generate(info, context); + } + + public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) + { + return manualMarshallingGenerator.SupportsByValueMarshalKind(marshalKind, context); + } + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) + { + if (IsPinningPathSupported(info, context)) + { + return false; + } + return manualMarshallingGenerator.UsesNativeIdentifier(info, context); + } + private static bool IsPinningPathSupported(TypePositionInfo info, StubCodeContext context) + { + return context.PinningSupported && !info.IsByRef && !info.IsManagedReturnPosition; + } + + private IEnumerable GeneratePinningPath(TypePositionInfo info, StubCodeContext context) + { + if (context.CurrentStage == StubCodeContext.Stage.Pin) + { + var (managedIdentifier, nativeIdentifier) = context.GetIdentifiers(info); + yield return FixedStatement( + VariableDeclaration( + PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))), + SingletonSeparatedList( + VariableDeclarator(Identifier(nativeIdentifier)) + .WithInitializer(EqualsValueClause( + IdentifierName(managedIdentifier) + )) + ) + ), + EmptyStatement() + ); + } + } + } +} From a6eeac9bb7f6e702a2f78d5f604e33c933876032 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 7 Jun 2021 17:25:51 -0700 Subject: [PATCH 27/30] Use enum instead of a bool. --- .../Analyzers/ManualTypeMarshallingAnalyzer.cs | 4 ++-- .../ManualTypeMarshallingHelper.cs | 14 ++++++++++---- .../DllImportGenerator/MarshallingAttributeInfo.cs | 9 ++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs b/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs index 68b935619852..965592e6522a 100644 --- a/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs +++ b/DllImportGenerator/DllImportGenerator/Analyzers/ManualTypeMarshallingAnalyzer.cs @@ -364,9 +364,9 @@ private void AnalyzeNativeMarshalerType(SymbolAnalysisContext context, ITypeSymb continue; } - hasConstructor = hasConstructor || ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, false); + hasConstructor = hasConstructor || ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, ManualTypeMarshallingHelper.NativeTypeMarshallingVariant.Standard); - if (!hasStackallocConstructor && ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, SpanOfByte, false)) + if (!hasStackallocConstructor && ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, SpanOfByte, ManualTypeMarshallingHelper.NativeTypeMarshallingVariant.Standard)) { hasStackallocConstructor = true; IFieldSymbol stackAllocSizeField = nativeType.GetMembers("StackBufferSize").OfType().FirstOrDefault(); diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index ab9000d94bad..ce6c521ba1f0 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -23,6 +23,12 @@ public static class MarshalUsingProperties public const string ConstantElementCount = "ConstantElementCount"; } + public enum NativeTypeMarshallingVariant + { + Standard, + ContiguousCollection + } + public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol managedType) { return nativeType.GetMembers(ToManagedMethodName) @@ -37,9 +43,9 @@ public static bool HasToManagedMethod(ITypeSymbol nativeType, ITypeSymbol manage public static bool IsManagedToNativeConstructor( IMethodSymbol ctor, ITypeSymbol managedType, - bool isCollectionMarshaller) + NativeTypeMarshallingVariant variant) { - if (isCollectionMarshaller) + if (variant == NativeTypeMarshallingVariant.ContiguousCollection) { return ctor.Parameters.Length == 2 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) @@ -53,9 +59,9 @@ public static bool IsStackallocConstructor( IMethodSymbol ctor, ITypeSymbol managedType, ITypeSymbol spanOfByte, - bool isCollectionMarshaller) + NativeTypeMarshallingVariant variant) { - if (isCollectionMarshaller) + if (variant == NativeTypeMarshallingVariant.ContiguousCollection) { return ctor.Parameters.Length == 3 && SymbolEqualityComparer.Default.Equals(managedType, ctor.Parameters[0].Type) diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index 47a481ab324a..844934e619e9 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -550,15 +550,18 @@ MarshallingInfo CreateNativeMarshallingInfo( bool isContiguousCollectionMarshaller = nativeType.GetAttributes().Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, contiguousCollectionMarshalerAttribute)); IPropertySymbol? valueProperty = ManualTypeMarshallingHelper.FindValueProperty(nativeType); + var marshallingVariant = isContiguousCollectionMarshaller + ? ManualTypeMarshallingHelper.NativeTypeMarshallingVariant.ContiguousCollection + : ManualTypeMarshallingHelper.NativeTypeMarshallingVariant.Standard; + bool hasInt32Constructor = false; foreach (var ctor in nativeType.Constructors) { - if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, isCollectionMarshaller: isContiguousCollectionMarshaller) - && (valueProperty is null or { GetMethod: not null })) + if (ManualTypeMarshallingHelper.IsManagedToNativeConstructor(ctor, type, marshallingVariant) && (valueProperty is null or { GetMethod: not null })) { methods |= SupportedMarshallingMethods.ManagedToNative; } - else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, isCollectionMarshaller: isContiguousCollectionMarshaller) + else if (ManualTypeMarshallingHelper.IsStackallocConstructor(ctor, type, spanOfByte, marshallingVariant) && (valueProperty is null or { GetMethod: not null })) { methods |= SupportedMarshallingMethods.ManagedToNativeStackalloc; From 08450159c0b90ad92a72b0f6b99cd00cad4c00e7 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 7 Jun 2021 17:27:46 -0700 Subject: [PATCH 28/30] Seal more types. --- .../DllImportGenerator/Marshalling/ArrayMarshaller.cs | 2 +- .../Marshalling/PinnableManagedValueMarshaller.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs index c5838517bb3d..be391398fdf0 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -5,7 +5,7 @@ namespace Microsoft.Interop { - internal class ArrayMarshaller : IMarshallingGenerator + internal sealed class ArrayMarshaller : IMarshallingGenerator { private readonly IMarshallingGenerator manualMarshallingGenerator; private readonly TypeSyntax elementType; diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs index 1437ace86499..eff8a6dfa055 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/PinnableManagedValueMarshaller.cs @@ -5,7 +5,7 @@ namespace Microsoft.Interop { - internal class PinnableManagedValueMarshaller : IMarshallingGenerator + internal sealed class PinnableManagedValueMarshaller : IMarshallingGenerator { private readonly IMarshallingGenerator manualMarshallingGenerator; From a5efcd9a1ca028583ce324631397ed32a8433710 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 8 Jun 2021 10:36:40 -0700 Subject: [PATCH 29/30] PR feedback. --- .../ManualTypeMarshallingHelper.cs | 15 +++++++-------- .../ICustomNativeTypeMarshallingStrategy.cs | 1 - .../MarshallingAttributeInfo.cs | 6 +++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs index ce6c521ba1f0..1d073064a39d 100644 --- a/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs +++ b/DllImportGenerator/DllImportGenerator/ManualTypeMarshallingHelper.cs @@ -18,9 +18,9 @@ static class ManualTypeMarshallingHelper public static class MarshalUsingProperties { - public const string ElementIndirectionLevel = "ElementIndirectionLevel"; - public const string CountElementName = "CountElementName"; - public const string ConstantElementCount = "ConstantElementCount"; + public const string ElementIndirectionLevel = nameof(ElementIndirectionLevel); + public const string CountElementName = nameof(CountElementName); + public const string ConstantElementCount = nameof(ConstantElementCount); } public enum NativeTypeMarshallingVariant @@ -99,19 +99,18 @@ public static bool HasFreeNativeMethod(ITypeSymbol type) .Any(m => m is { IsStatic: false, Parameters: { Length: 0 }, ReturnType: { SpecialType: SpecialType.System_Void } }); } - public static IPropertySymbol? FindManagedValuesProperty(ITypeSymbol type) + public static bool TryGetManagedValuesProperty(ITypeSymbol type, out IPropertySymbol managedValuesProperty) { - return type + managedValuesProperty = type .GetMembers(ManagedValuesPropertyName) .OfType() .FirstOrDefault(p => p is { IsStatic: false, GetMethod: not null, ReturnsByRef: false, ReturnsByRefReadonly: false }); + return managedValuesProperty is not null; } public static bool TryGetElementTypeFromContiguousCollectionMarshaller(ITypeSymbol type, out ITypeSymbol elementType) { - IPropertySymbol? managedValuesProperty = FindManagedValuesProperty(type); - - if (managedValuesProperty is null) + if (!TryGetManagedValuesProperty(type, out IPropertySymbol managedValuesProperty)) { elementType = null!; return false; diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs index 4ac1cdc52e7f..e3999e72f3a4 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs @@ -11,7 +11,6 @@ namespace Microsoft.Interop /// /// The base interface for implementing various different aspects of the custom native type and collection marshalling specs. /// - interface ICustomNativeTypeMarshallingStrategy { TypeSyntax AsNativeType(TypePositionInfo info); diff --git a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs index 844934e619e9..1590fc49f696 100644 --- a/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs +++ b/DllImportGenerator/DllImportGenerator/MarshallingAttributeInfo.cs @@ -123,7 +123,7 @@ internal sealed record GeneratedNativeMarshallingAttributeInfo( internal sealed record SafeHandleMarshallingInfo(bool AccessibleDefaultConstructor) : MarshallingInfo; /// - /// User-applied System.Runtime.InteropServices.NativeMarshalllingAttribute + /// User-applied System.Runtime.InteropServices.NativeMarshallingAttribute /// with a contiguous collection marshaller internal sealed record NativeContiguousCollectionMarshallingInfo( ITypeSymbol NativeMarshallingType, @@ -610,7 +610,7 @@ MarshallingInfo CreateNativeMarshallingInfo( UseDefaultMarshalling: !isMarshalUsingAttribute, parsedCountInfo, elementType, - GetMarshallingInfo(elementType, useSiteAttributes, maxIndirectionLevelUsed = indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); + GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); } return new NativeMarshallingAttributeInfo( @@ -680,7 +680,7 @@ bool TryCreateTypeBasedMarshallingInfo( UseDefaultMarshalling: true, ElementCountInfo: parsedCountInfo, ElementType: elementType, - ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, maxIndirectionLevelUsed = indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); + ElementMarshallingInfo: GetMarshallingInfo(elementType, useSiteAttributes, indirectionLevel + 1, inspectedElements, ref maxIndirectionLevelUsed)); return true; } From 2836c29deade6be92116cc835f6d95b0b47367df Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Tue, 8 Jun 2021 11:54:00 -0700 Subject: [PATCH 30/30] Fix various test failures around element marshalling, by-value contents marshalling, and marshaller pinning. --- ...CollectionElementMarshallingCodeContext.cs | 4 +- .../Marshalling/ArrayMarshaller.cs | 6 +-- .../CustomNativeTypeMarshallingGenerator.cs | 11 +++-- .../ICustomNativeTypeMarshallingStrategy.cs | 41 +++++++++++++------ .../Marshalling/MarshallingGenerator.cs | 8 ++-- 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs b/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs index d7cc9123e731..89c3b03590de 100644 --- a/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs +++ b/DllImportGenerator/DllImportGenerator/ContiguousCollectionElementMarshallingCodeContext.cs @@ -57,9 +57,9 @@ public ContiguousCollectionElementMarshallingCodeContext( /// Managed and native identifiers public override (string managed, string native) GetIdentifiers(TypePositionInfo info) { - var (managed, _) = parentContext.GetIdentifiers(info); + var (_, native) = parentContext.GetIdentifiers(info); return ( - $"{MarshallerHelpers.GetMarshallerIdentifier(info, parentContext)}.ManagedValues[{indexerIdentifier}]", + $"{native}.ManagedValues[{indexerIdentifier}]", $"{nativeSpanIdentifier}[{indexerIdentifier}]" ); } diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs index be391398fdf0..6d0f98c533c4 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ArrayMarshaller.cs @@ -49,11 +49,11 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) { - if (!(context.PinningSupported && enablePinning)) + if (context.PinningSupported && enablePinning) { - return marshalKind.HasFlag(ByValueContentsMarshalKind.Out); + return false; } - return manualMarshallingGenerator.SupportsByValueMarshalKind(marshalKind, context); + return marshalKind.HasFlag(ByValueContentsMarshalKind.Out); } public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs index 849af1258092..1f81e5040cf6 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/CustomNativeTypeMarshallingGenerator.cs @@ -12,10 +12,12 @@ namespace Microsoft.Interop internal sealed class CustomNativeTypeMarshallingGenerator : IMarshallingGenerator { private readonly ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller; + private readonly bool enableByValueContentsMarshalling; - public CustomNativeTypeMarshallingGenerator(ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller) + public CustomNativeTypeMarshallingGenerator(ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller, bool enableByValueContentsMarshalling) { this.nativeTypeMarshaller = nativeTypeMarshaller; + this.enableByValueContentsMarshalling = enableByValueContentsMarshalling; } public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) @@ -39,6 +41,8 @@ public ParameterSyntax AsParameter(TypePositionInfo info) public IEnumerable Generate(TypePositionInfo info, StubCodeContext context) { + // Although custom native type marshalling doesn't support [In] or [Out] by value marshalling, + // other marshallers that wrap this one might, so we handle the correct cases here. switch (context.CurrentStage) { case StubCodeContext.Stage.Setup: @@ -56,7 +60,8 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont } break; case StubCodeContext.Stage.Unmarshal: - if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In)) + if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In) + || (enableByValueContentsMarshalling && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))) { return nativeTypeMarshaller.GenerateUnmarshalStatements(info, context); } @@ -72,7 +77,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont public bool SupportsByValueMarshalKind(ByValueContentsMarshalKind marshalKind, StubCodeContext context) { - return false; + return enableByValueContentsMarshalling; } public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs b/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs index e3999e72f3a4..283afaf60669 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/ICustomNativeTypeMarshallingStrategy.cs @@ -131,6 +131,7 @@ internal class CustomNativeTypeWithValuePropertyStubContext : StubCodeContext public CustomNativeTypeWithValuePropertyStubContext(StubCodeContext parentContext) { this.parentContext = parentContext; + CurrentStage = parentContext.CurrentStage; } public override bool PinningSupported => parentContext.PinningSupported; @@ -453,7 +454,7 @@ public PinnableMarshallerTypeMarshalling(ICustomNativeTypeMarshallingStrategy in private bool CanPinMarshaller(TypePositionInfo info, StubCodeContext context) { - return !info.IsByRef || info.RefKind == RefKind.In; + return context.PinningSupported && !info.IsManagedReturnPosition && !info.IsByRef || info.RefKind == RefKind.In; } public ArgumentSyntax AsArgument(TypePositionInfo info, StubCodeContext context) @@ -499,7 +500,7 @@ public IEnumerable GeneratePinStatements(TypePositionInfo info, var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); yield return FixedStatement( VariableDeclaration( - innerMarshaller.AsNativeType(info), + valuePropertyType, SingletonSeparatedList( VariableDeclarator(context.GetIdentifiers(info).native) .WithInitializer(EqualsValueClause( @@ -532,14 +533,17 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo { var subContext = new CustomNativeTypeWithValuePropertyStubContext(context); - // .Value = ; - yield return ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(subContext.GetIdentifiers(info).native), - IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), - IdentifierName(context.GetIdentifiers(info).native))); + if (!CanPinMarshaller(info, context)) + { + // .Value = ; + yield return ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(subContext.GetIdentifiers(info).native), + IdentifierName(ManualTypeMarshallingHelper.ValuePropertyName)), + IdentifierName(context.GetIdentifiers(info).native))); + } foreach (var statement in innerMarshaller.GenerateUnmarshalStatements(info, subContext)) { @@ -554,7 +558,7 @@ public IEnumerable GetNativeTypeConstructorArguments(TypePositio public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) { - if (context.PinningSupported) + if (CanPinMarshaller(info, context)) { return false; } @@ -687,6 +691,12 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i yield return statement; } + if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + { + // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. + yield break; + } + // .ManagedValues.CopyTo(MemoryMarshal.Cast>(.NativeValueStorage)); yield return ExpressionStatement( InvocationExpression( @@ -903,11 +913,18 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context, IEnumerable nativeTypeConstructorArguments) { - yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: true); foreach (var statement in innerMarshaller.GenerateMarshalStatements(info, context, nativeTypeConstructorArguments)) { yield return statement; } + + if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + { + // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. + yield break; + } + + yield return GenerateContentsMarshallingStatement(info, context, useManagedSpanForLength: true); } public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) diff --git a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs index fe30c208b7f8..f406e2c3e6d6 100644 --- a/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs +++ b/DllImportGenerator/DllImportGenerator/Marshalling/MarshallingGenerator.cs @@ -437,7 +437,7 @@ private static IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositi marshallingStrategy = DecorateWithValuePropertyStrategy(marshalInfo, marshallingStrategy); } - IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy); + IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); if ((marshalInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0) { @@ -510,7 +510,7 @@ private static void ValidateCustomNativeTypeMarshallingSupported(TypePositionInf private static ICustomNativeTypeMarshallingStrategy DecorateWithValuePropertyStrategy(NativeMarshallingAttributeInfo marshalInfo, ICustomNativeTypeMarshallingStrategy nativeTypeMarshaller) { TypeSyntax valuePropertyTypeSyntax = marshalInfo.ValuePropertyType!.AsTypeSyntax(); - if (ManualTypeMarshallingHelper.FindGetPinnableReference(marshalInfo.ValuePropertyType!) is not null) + if (ManualTypeMarshallingHelper.FindGetPinnableReference(marshalInfo.NativeMarshallingType) is not null) { return new PinnableMarshallerTypeMarshalling(nativeTypeMarshaller, valuePropertyTypeSyntax); } @@ -564,12 +564,12 @@ private static IMarshallingGenerator CreateNativeCollectionMarshaller( if (collectionInfo.UseDefaultMarshalling && info.ManagedType is IArrayTypeSymbol { IsSZArray: true }) { return new ArrayMarshaller( - new CustomNativeTypeMarshallingGenerator(marshallingStrategy), + new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: true), elementType, isBlittable); } - IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy); + IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); if ((collectionInfo.MarshallingMethods & SupportedMarshallingMethods.Pinning) != 0) {