diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs index afa27b08e0e6d..a11ce0ecefdef 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs @@ -34,6 +34,7 @@ public LinearCollectionElementMarshallingCodeContext( _managedSpanIdentifier = managedSpanIdentifier; _nativeSpanIdentifier = nativeSpanIdentifier; ParentContext = parentContext; + Direction = ParentContext.Direction; } public override (TargetFramework framework, Version version) GetTargetFramework() diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs index 03b90462ea301..ee9809770583e 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs @@ -244,8 +244,22 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false); - if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) + FreeStrategy freeStrategy = GetFreeStrategy(info, context); + + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); + } + + if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free)) + { marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax); + } + + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); + } } IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false); @@ -311,19 +325,42 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax); } + FreeStrategy freeStrategy = GetFreeStrategy(info, context); IElementsMarshallingCollectionSource collectionSource = new StatefulLinearCollectionSource(); IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource); - marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling); + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); + } + + marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling, freeStrategy != FreeStrategy.NoFree); + + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); + } + + if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) + { + marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy); + } } else { marshallingStrategy = new StatelessLinearCollectionSpaceAllocator(marshallerTypeSyntax, nativeType, marshallerData.Shape, numElementsExpression); + FreeStrategy freeStrategy = GetFreeStrategy(info, context); + IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax); + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy); + } + IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource); - marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape); + marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != FreeStrategy.NoFree); if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) { @@ -334,8 +371,15 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax, isLinearCollectionMarshalling: true); } - if (marshallerData.Shape.HasFlag(MarshallerShape.Free)) + if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free)) + { marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax); + } + + if (freeStrategy == FreeStrategy.FreeOriginal) + { + marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy); + } } IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator( @@ -351,6 +395,48 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller( return marshallingGenerator; } + private enum FreeStrategy + { + /// + /// Free the unmanaged value stored in the native identifier. + /// + FreeNative, + /// + /// Free the unmanaged value originally passed into the stub. + /// + FreeOriginal, + /// + /// Do not free the unmanaged value, we don't own it. + /// + NoFree + } + + private static FreeStrategy GetFreeStrategy(TypePositionInfo info, StubCodeContext context) + { + // When marshalling from managed to unmanaged, we always own the value in the native identifier. + if (context.Direction == MarshalDirection.ManagedToUnmanaged) + { + return FreeStrategy.FreeNative; + } + + // When we're in a case where we don't have state across stages, the parent stub context that can track the state + // will only call our Cleanup stage when we own the value in the native identifier. + if (!context.AdditionalTemporaryStateLivesAcrossStages) + { + return FreeStrategy.FreeNative; + } + + // In an unmanaged-to-managed stub where a value is passed by 'ref', + // we own the original value once we replace it with the new value we're passing out to the caller. + if (info.RefKind == RefKind.Ref) + { + return FreeStrategy.FreeOriginal; + } + + // In an unmanaged-to-managed stub, we don't take ownership of the value when it isn't passed by 'ref'. + return FreeStrategy.NoFree; + } + private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource) { IElementsMarshalling elementsMarshalling; diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs index 02d0b6bf70a54..3c1468fbfe58c 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs @@ -22,9 +22,9 @@ internal interface IElementsMarshallingCollectionSource internal interface IElementsMarshalling { - StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context); + StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context); StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context); - StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context); + StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context); StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context); StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context); } @@ -45,7 +45,7 @@ public BlittableElementsMarshalling(TypeSyntax managedElementType, TypeSyntax un _collectionSource = collectionSource; } - public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) { // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content. @@ -73,7 +73,7 @@ public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeC Argument(destination))); } - public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) { ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context)); @@ -175,7 +175,7 @@ public NonBlittableElementsMarshalling( _collectionSource = collectionSource; } - public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context) { // If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection. // We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content. @@ -259,7 +259,7 @@ public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCod StubCodeContext.Stage.Unmarshal)); } - public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) + public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context) { // Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents, // not the array itself. @@ -356,7 +356,9 @@ public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, St VariableDeclarator( Identifier(nativeSpanIdentifier)) .WithInitializer(EqualsValueClause( - _collectionSource.GetUnmanagedValuesDestination(info, context)))))), + context.Direction == MarshalDirection.ManagedToUnmanaged + ? _collectionSource.GetUnmanagedValuesDestination(info, context) + : _collectionSource.GetUnmanagedValuesSource(info, context)))))), contentsCleanupStatements); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs index 6bf87839e5846..1790e588a8e79 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs @@ -88,7 +88,7 @@ public IMarshallingGenerator Create( return s_delegate; case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }: - if (!context.AdditionalTemporaryStateLivesAcrossStages) + if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged) { throw new MarshallingNotSupportedException(info, context); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs index 8c8aa0db66ccb..2163c71047f87 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs @@ -372,40 +372,36 @@ internal sealed class StatefulLinearCollectionMarshalling : ICustomTypeMarshalli private readonly MarshallerShape _shape; private readonly ExpressionSyntax _numElementsExpression; private readonly IElementsMarshalling _elementsMarshalling; + private readonly bool _cleanupElements; public StatefulLinearCollectionMarshalling( ICustomTypeMarshallingStrategy innerMarshaller, MarshallerShape shape, ExpressionSyntax numElementsExpression, - IElementsMarshalling elementsMarshalling) + IElementsMarshalling elementsMarshalling, + bool cleanupElements) { _innerMarshaller = innerMarshaller; _shape = shape; _numElementsExpression = numElementsExpression; _elementsMarshalling = elementsMarshalling; + _cleanupElements = cleanupElements; } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) { + if (!_cleanupElements) + { + yield break; + } + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) { yield return elementCleanup; } - - if (!_shape.HasFlag(MarshallerShape.Free)) - yield break; - - string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); - // .Free(); - yield return ExpressionStatement( - InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(marshaller), - IdentifierName(ShapeMemberNames.Free)), - ArgumentList())); } public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); @@ -419,9 +415,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i yield return statement; } - if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) { - yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context); yield break; } @@ -437,9 +433,10 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo { string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context); - if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { - yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context); + yield break; } if (!_shape.HasFlag(MarshallerShape.ToManaged)) @@ -469,4 +466,50 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true; } + + /// + /// Marshaller that enables calling the Free method on a stateful marshaller. + /// + internal sealed class StatefulFreeMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + + public StatefulFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller) + { + _innerMarshaller = innerMarshaller; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context)) + { + yield return statement; + } + + string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context); + // .Free(); + yield return ExpressionStatement( + InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(marshaller), + IdentifierName(ShapeMemberNames.Free)), + ArgumentList())); + } + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context); + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + } } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs index 94ea82a31b6d0..fd9459ca79459 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs @@ -227,7 +227,7 @@ IEnumerable GenerateCallerAllocatedBufferMarshalStatements() } else { - // = .ConvertToUnmanaged(, __buffer); + // = .ConvertToUnmanaged(, __buffer); yield return ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, @@ -295,6 +295,134 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); } + internal sealed class StatelessUnmanagedToManagedOwnershipTracking : ICustomTypeMarshallingStrategy + { + internal const string OwnOriginalValueIdentifier = "ownOriginal"; + internal const string OriginalValueIdentifier = "original"; + + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + + public StatelessUnmanagedToManagedOwnershipTracking(ICustomTypeMarshallingStrategy innerMarshaller) + { + _innerMarshaller = innerMarshaller; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context); + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context)) + { + yield return statement; + } + + // Now that we've set the new value to pass to the caller on the identifier, we need to make sure that we free the old one. + // The caller will not see the old one any more, so it won't be able to free it. + + // = true; + yield return ExpressionStatement( + AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)), + LiteralExpression(SyntaxKind.TrueLiteralExpression))); + } + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) + { + foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context)) + { + yield return statement; + } + + // bool = false; + yield return LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.BoolKeyword)), + SingletonSeparatedList( + VariableDeclarator( + Identifier(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)), + null, + EqualsValueClause( + LiteralExpression(SyntaxKind.FalseLiteralExpression)))))); + + // = ; + yield return LocalDeclarationStatement( + VariableDeclaration( + AsNativeType(info).Syntax, + SingletonSeparatedList( + VariableDeclarator( + Identifier(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)), + null, + EqualsValueClause( + IdentifierName(context.GetIdentifiers(info).native)))))); + } + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + } + + internal sealed class FreeOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy + { + private readonly ICustomTypeMarshallingStrategy _innerMarshaller; + + public FreeOwnedOriginalValueMarshalling(ICustomTypeMarshallingStrategy innerMarshaller) + { + _innerMarshaller = innerMarshaller; + } + + public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info); + + public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) + { + // if () + // { + // + // } + yield return IfStatement( + IdentifierName(context.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OwnOriginalValueIdentifier)), + Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context)))); + } + + public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context); + public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context); + + public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context); + public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context); + + public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context); + public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context); + + public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context); + + public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context); + public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context); + + private sealed record OwnedValueCodeContext(StubCodeContext InnerContext) : StubCodeContext + { + public override bool SingleFrameSpansNativeContext => InnerContext.SingleFrameSpansNativeContext; + + public override bool AdditionalTemporaryStateLivesAcrossStages => InnerContext.AdditionalTemporaryStateLivesAcrossStages; + + public override (TargetFramework framework, Version version) GetTargetFramework() => InnerContext.GetTargetFramework(); + + public override (string managed, string native) GetIdentifiers(TypePositionInfo info) + { + var (managed, _) = InnerContext.GetIdentifiers(info); + return (managed, InnerContext.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OriginalValueIdentifier)); + } + + public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => InnerContext.GetAdditionalIdentifier(info, name); + } + } + /// /// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec. /// @@ -404,7 +532,7 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { - if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { yield break; } @@ -535,23 +663,31 @@ internal sealed class StatelessLinearCollectionMarshalling : ICustomTypeMarshall private readonly IElementsMarshalling _elementsMarshalling; private readonly ManagedTypeInfo _unmanagedType; private readonly MarshallerShape _shape; + private readonly bool _cleanupElementsAndSpace; public StatelessLinearCollectionMarshalling( ICustomTypeMarshallingStrategy spaceMarshallingStrategy, IElementsMarshalling elementsMarshalling, ManagedTypeInfo unmanagedType, - MarshallerShape shape) + MarshallerShape shape, + bool cleanupElementsAndSpace) { _spaceMarshallingStrategy = spaceMarshallingStrategy; _elementsMarshalling = elementsMarshalling; _unmanagedType = unmanagedType; _shape = shape; + _cleanupElementsAndSpace = cleanupElementsAndSpace; } public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _unmanagedType; public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) { + if (!_cleanupElementsAndSpace) + { + yield break; + } + StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context); if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement)) @@ -576,9 +712,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer)) yield break; - if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out) { - yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context); } else { @@ -596,9 +732,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) { - if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) + if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out)) { - yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context); + yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context); yield break; } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs index 3dbd182406643..7940f6a15fd23 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs @@ -11,7 +11,7 @@ public sealed record NativeToManagedStubCodeContext : StubCodeContext { public override bool SingleFrameSpansNativeContext => false; - public override bool AdditionalTemporaryStateLivesAcrossStages => false; + public override bool AdditionalTemporaryStateLivesAcrossStages => true; private readonly TargetFramework _framework; private readonly Version _frameworkVersion; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs index 937ba3de524ef..68c7420d0b3cb 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs @@ -48,10 +48,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce() iface.SetInt(5); Assert.Equal(5, iface.GetInt()); - // https://github.com/dotnet/runtime/issues/85795 - //Assert.Equal("myName", iface.GetName()); - //iface.SetName("updated"); - //Assert.Equal("updated", iface.GetName()); + Assert.Equal("myName", iface.GetName()); + iface.SetName("updated"); + Assert.Equal("updated", iface.GetName()); var iUnknownStrategyProperty = typeof(ComObject).GetProperty("IUnknownStrategy", BindingFlags.NonPublic | BindingFlags.Instance); @@ -67,7 +66,7 @@ partial class DerivedImpl : IDerived { int data = 3; string myName = "myName"; - public void DoThingWithString([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => throw new NotImplementedException(); + public void DoThingWithString(string name) => throw new NotImplementedException(); public int GetInt() => data; diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs index e6214e759450e..20ee9b4a27668 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs @@ -2,10 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; -using ComInterfaceGenerator.Tests; +using System.Threading; using Xunit; namespace ComInterfaceGenerator.Tests @@ -17,8 +18,7 @@ internal partial class ImplicitThis [UnmanagedObjectUnwrapperAttribute>] internal partial interface INativeObject : IUnmanagedInterfaceType { - - private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2); + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 6); static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation { get @@ -35,6 +35,22 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation int GetData(); [VirtualMethodIndex(1, ImplicitThisParameter = true)] void SetData(int x); + [VirtualMethodIndex(2, ImplicitThisParameter = true)] + void ExchangeData(ref int x); + [VirtualMethodIndex(3, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues))] int[] values, + int numValues, + out int oldValue); + [VirtualMethodIndex(4, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues))] ref int[] values, + int numValues, + out int oldValue); + [VirtualMethodIndex(5, ImplicitThisParameter = true)] + void MultiplyWithData( + [MarshalUsing(CountElementName = nameof(numValues))] int[] values, + int numValues); } [NativeMarshalling(typeof(NativeObjectMarshaller))] @@ -105,16 +121,21 @@ public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed() void* wrapper = VTableGCHandlePair.Allocate(impl); - Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); - NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue); - Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); - // Verify that we actually updated the managed instance. - Assert.Equal(newValue, impl.GetData()); - - VTableGCHandlePair.Free(wrapper); + try + { + Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); + NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue); + Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper)); + // Verify that we actually updated the managed instance. + Assert.Equal(newValue, impl.GetData()); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } } - class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject + sealed class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject { private int _data; @@ -123,8 +144,24 @@ public ManagedObjectImplementation(int value) _data = value; } + public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x); public int GetData() => _data; + public void MultiplyWithData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues) + { + for (int i = 0; i < values.Length; i++) + { + values[i] *= _data; + } + } public void SetData(int x) => _data = x; + public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues, out int oldValue) + { + int value = values.Sum(); + oldValue = _data; + _data = value; + } + + public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] ref int[] values, int numValues, out int oldValue) => SumAndSetData(values, numValues, out oldValue); } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs new file mode 100644 index 0000000000000..224ea68c41477 --- /dev/null +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -0,0 +1,471 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.Marshalling; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using SharedTypes; +using Xunit; +using static ComInterfaceGenerator.Tests.UnmanagedToManagedCustomMarshallingTests; + +namespace ComInterfaceGenerator.Tests +{ + internal unsafe partial class NativeExportsNE + { + internal partial class UnmanagedToManagedCustomMarshalling + { + [UnmanagedObjectUnwrapper>] + internal partial interface INativeObject : IUnmanagedInterfaceType + { + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 6); + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + if (s_vtable[0] == null) + { + Native.PopulateUnmanagedVirtualMethodTable(s_vtable); + } + return s_vtable; + } + } + + [VirtualMethodIndex(0, ImplicitThisParameter = true)] + [return: MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] + IntWrapper GetData(); + [VirtualMethodIndex(1, ImplicitThisParameter = true)] + void SetData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] IntWrapper x); + [VirtualMethodIndex(2, ImplicitThisParameter = true)] + void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data); + [VirtualMethodIndex(3, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + [VirtualMethodIndex(4, ImplicitThisParameter = true)] + void SumAndSetData( + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + [VirtualMethodIndex(5, ImplicitThisParameter = true)] + void MultiplyWithData( + [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1), In, Out] IntWrapper[] values123, + int numValues); + } + + [UnmanagedObjectUnwrapper>] + internal partial interface INativeObjectStateful : IUnmanagedInterfaceType + { + private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObjectStateful), sizeof(void*) * 6); + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + if (s_vtable[0] == null) + { + Native.PopulateUnmanagedVirtualMethodTable(s_vtable); + } + return s_vtable; + } + } + + [VirtualMethodIndex(3, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)] + void SumAndSetData( + [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))] + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + [VirtualMethodIndex(4, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)] + void SumAndSetData( + [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))] + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values, + int numValues, + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue); + + [VirtualMethodIndex(5, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)] + void MultiplyWithData( + [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))] + [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1), In, Out] IntWrapper[] values123, + int numValues); + } + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_native_object_data")] + public static partial void SetNativeObjectData(void* obj, int data); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_native_object_data")] + public static partial int GetNativeObjectData(void* obj); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "exchange_native_object_data")] + public static partial int ExchangeNativeObjectData(void* obj, ref int x); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data")] + public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues, out int oldValue); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data_with_ref")] + public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] ref int[] arr, int numValues, out int oldValue); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = "multiply_with_native_object_data")] + public static partial int MultiplyWithNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues); + } + } + public class UnmanagedToManagedCustomMarshallingTests + { + [Fact] + public unsafe void ValidateOnlyByRefFreed_Stateless() + { + const int startingValue = 13; + const int newValue = 42; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + NativeExportsNE.UnmanagedToManagedCustomMarshalling.GetNativeObjectData(wrapper); + + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SetNativeObjectData(wrapper, newValue); + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + + int finalValue = 10; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.ExchangeNativeObjectData(wrapper, ref finalValue); + Assert.Equal(freeCalls + 1, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsAndOutParameterNotFreed_Stateless() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, values, values.Length, out int _); + + Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsByRefFreed_Stateless() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, ref values, values.Length, out int _); + + Assert.Equal(freeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")] + public unsafe void ValidateArrayElementsByValueOutFreed_Stateless() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + var expected = values.Select(x => x * startingValue).ToArray(); + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length); + + Assert.Equal(expected, values); + + Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsAndOutParameterNotFreed_Stateful() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, values, values.Length, out int _); + + // We shouldn't free the elements, but we always free the stateful marshaller. + Assert.Equal(elementFreeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + public unsafe void ValidateArrayElementsByRefFreed_Stateful() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.Ref.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, ref values, values.Length, out int _); + + Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.Ref.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")] + public unsafe void ValidateArrayElementsByValueOutFreed_Stateful() + { + const int startingValue = 13; + + ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue); + + void* wrapper = VTableGCHandlePair.Allocate(impl); + + try + { + var values = new int[] { 1, 32, 63, 124, 255 }; + var expected = values.Select(x => x * startingValue).ToArray(); + + int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree; + int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree; + + NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length); + + Assert.Equal(expected, values); + + Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree); + Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree); + } + finally + { + VTableGCHandlePair.Free(wrapper); + } + } + + sealed unsafe class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject, NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObjectStateful + { + private IntWrapper _data; + + public ManagedObjectImplementation(int value) + { + _data = new() { i = value }; + } + + public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x); + public IntWrapper GetData() => _data; + public void MultiplyWithData(IntWrapper[] values, int numValues) + { + for (int i = 0; i < values.Length; i++) + { + values[i].i *= _data.i; + } + } + public void SetData(IntWrapper x) => _data = x; + public void SumAndSetData(ref IntWrapper[] values, int numValues, out IntWrapper oldValue) => SumAndSetData(values, numValues, out oldValue); + public void SumAndSetData(IntWrapper[] values, int numValues, out IntWrapper oldValue) + { + int value = values.Sum(value => value.i); + oldValue = _data; + _data = new() { i = value }; + } + + static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation + { + get + { + Assert.Fail("The VirtualMethodTableManagedImplementation property should not be called on implementing class types"); + return null; + } + } + } + + [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerToIntWithFreeCounts))] + public static unsafe class IntWrapperMarshallerToIntWithFreeCounts + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + public static int ConvertToUnmanaged(IntWrapper managed) + { + return managed.i; + } + + public static IntWrapper ConvertToManaged(int unmanaged) + { + return new IntWrapper { i = unmanaged }; + } + + public static void Free(int _) + { + NumCallsToFree++; + } + } + + [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), MarshalMode.UnmanagedToManagedIn, typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>.In))] + [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), MarshalMode.UnmanagedToManagedRef, typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>.Ref))] + [ContiguousCollectionMarshaller] + public unsafe static class StatefulUnmanagedToManagedCollectionMarshaller + where TUnmanaged : unmanaged + { + public struct In + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + private TUnmanaged* _unmanaged; + private TManaged[] _managed; + + public void FromUnmanaged(TUnmanaged* unmanaged) + { + _unmanaged = unmanaged; + } + + public Span GetManagedValuesDestination(int numElements) + { + return _managed = new TManaged[numElements]; + } + + public ReadOnlySpan GetUnmanagedValuesSource(int numElements) + { + return new(_unmanaged, numElements); + } + + public TManaged[] ToManaged() + { + return _managed; + } + + public void Free() + { + NumCallsToFree++; + } + } + + public struct Ref + { + [ThreadStatic] + public static int NumCallsToFree = 0; + + private TUnmanaged* _originalUnmanaged; + private TUnmanaged* _unmanaged; + private TManaged[] _managed; + + public void FromUnmanaged(TUnmanaged* unmanaged) + { + _originalUnmanaged = unmanaged; + } + + public Span GetManagedValuesDestination(int numElements) + { + return _managed = new TManaged[numElements]; + } + + public ReadOnlySpan GetUnmanagedValuesSource(int numElements) + { + return new(_originalUnmanaged, numElements); + } + + public TManaged[] ToManaged() + { + return _managed; + } + + public void Free() + { + Marshal.FreeCoTaskMem((nint)_originalUnmanaged); + NumCallsToFree++; + } + + public void FromManaged(TManaged[] managed) + { + _managed = managed; + } + + public TUnmanaged* ToUnmanaged() + { + return _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length); + } + + public ReadOnlySpan GetManagedValuesSource() + { + return _managed; + } + + public Span GetUnmanagedValuesDestination() + { + return new(_unmanaged, _managed.Length); + } + } + } + } +} diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs index 89c1e84fd5ab8..628db42618719 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs @@ -327,7 +327,6 @@ public static IEnumerable CustomCollectionsManagedToUnmanaged(Generato [MemberData(nameof(UnmanagedToManagedCodeSnippetsToCompile), GeneratorKind.VTableIndexStubGenerator)] [MemberData(nameof(CustomCollectionsManagedToUnmanaged), GeneratorKind.VTableIndexStubGenerator)] [MemberData(nameof(CustomCollections), GeneratorKind.VTableIndexStubGenerator)] - [MemberData(nameof(CustomCollections), GeneratorKind.VTableIndexStubGenerator)] public async Task ValidateVTableIndexSnippets(string id, string source) { _ = id; diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs index a286dd4de3614..a8e6458fb0ab5 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/VirtualMethodTables.cs @@ -8,6 +8,7 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading; using System.Threading.Tasks; namespace NativeExports @@ -48,6 +49,10 @@ public struct VirtualFunctionTable { public delegate* unmanaged getData; public delegate* unmanaged setData; + public delegate* unmanaged exchangeData; + public delegate* unmanaged sumAndSetData; + public delegate* unmanaged sumAndSetDataWithRef; + public delegate* unmanaged multiplyWithData; } public readonly VirtualFunctionTable* VTable; @@ -66,12 +71,20 @@ public struct VirtualFunctionTable // The order of functions here should match NativeObjectInterface.VirtualFunctionTable's members. public delegate* unmanaged getData; public delegate* unmanaged setData; + public delegate* unmanaged exchangeData; + public delegate* unmanaged sumAndSetData; + public delegate* unmanaged sumAndSetDataWithRef; + public delegate* unmanaged multiplyWithData; } static NativeObject() { VTablePointer = (VirtualFunctionTable*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(NativeObject), sizeof(VirtualFunctionTable)); VTablePointer->getData = &GetData; VTablePointer->setData = &SetData; + VTablePointer->exchangeData = &ExchangeData; + VTablePointer->sumAndSetData = &SumAndSetData; + VTablePointer->sumAndSetDataWithRef = &SumAndSetData; + VTablePointer->multiplyWithData = &MultiplyWithData; } private static readonly VirtualFunctionTable* VTablePointer; @@ -95,6 +108,52 @@ private static void SetData(NativeObject* obj, int value) { obj->Data = value; } + + [UnmanagedCallersOnly] + private static void ExchangeData(NativeObject* obj, int* value) + { + var temp = obj->Data; + obj->Data = *value; + *value = temp; + } + + [UnmanagedCallersOnly] + private static void SumAndSetData(NativeObject* obj, int** values, int numValues, int* oldValue) + { + *oldValue = obj->Data; + + Span arr = new(*values, numValues); + int sum = 0; + foreach (int value in arr) + { + sum += value; + } + obj->Data = sum; + } + + [UnmanagedCallersOnly] + private static void SumAndSetData(NativeObject* obj, int* values, int numValues, int* oldValue) + { + *oldValue = obj->Data; + + Span arr = new(values, numValues); + int sum = 0; + foreach (int value in arr) + { + sum += value; + } + obj->Data = sum; + } + + [UnmanagedCallersOnly] + private static void MultiplyWithData(NativeObject* obj, int* values, int numValues) + { + Span arr = new(values, numValues); + foreach (ref int value in arr) + { + value *= obj->Data; + } + } } [UnmanagedCallersOnly(EntryPoint = "new_native_object")] @@ -127,5 +186,33 @@ public static int GetNativeObjectData([DNNE.C99Type("struct INativeObject*")] Na { return obj->VTable->getData(obj); } + + [UnmanagedCallersOnly(EntryPoint = "exchange_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void ExchangeNativeObjectData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* x) + { + obj->VTable->exchangeData(obj, x); + } + + [UnmanagedCallersOnly(EntryPoint = "sum_and_set_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void SumAndSetData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* values, int numValues, int* oldValue) + { + obj->VTable->sumAndSetData(obj, values, numValues, oldValue); + } + + [UnmanagedCallersOnly(EntryPoint = "sum_and_set_native_object_data_with_ref")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void SumAndSetDataWithRef([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int** values, int numValues, int* oldValue) + { + obj->VTable->sumAndSetDataWithRef(obj, values, numValues, oldValue); + } + + [UnmanagedCallersOnly(EntryPoint = "multiply_with_native_object_data")] + [DNNE.C99DeclCode("struct INativeObject;")] + public static void MultiplyWithData([DNNE.C99Type("struct INativeObject*")] NativeObjectInterface* obj, int* values, int numValues) + { + obj->VTable->multiplyWithData(obj, values, numValues); + } } }