diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs
index afa27b08e0e6d..a11ce0ecefdef 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/LinearCollectionElementMarshallingCodeContext.cs
@@ -34,6 +34,7 @@ public LinearCollectionElementMarshallingCodeContext(
_managedSpanIdentifier = managedSpanIdentifier;
_nativeSpanIdentifier = nativeSpanIdentifier;
ParentContext = parentContext;
+ Direction = ParentContext.Direction;
}
public override (TargetFramework framework, Version version) GetTargetFramework()
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs
index 03b90462ea301..ee9809770583e 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/AttributedMarshallingModelGeneratorFactory.cs
@@ -244,8 +244,22 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false);
- if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
+ FreeStrategy freeStrategy = GetFreeStrategy(info, context);
+
+ if (freeStrategy == FreeStrategy.FreeOriginal)
+ {
+ marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
+ }
+
+ if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
+ {
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax);
+ }
+
+ if (freeStrategy == FreeStrategy.FreeOriginal)
+ {
+ marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
+ }
}
IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
@@ -311,19 +325,42 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax);
}
+ FreeStrategy freeStrategy = GetFreeStrategy(info, context);
IElementsMarshallingCollectionSource collectionSource = new StatefulLinearCollectionSource();
IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);
- marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling);
+ if (freeStrategy == FreeStrategy.FreeOriginal)
+ {
+ marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
+ }
+
+ marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling, freeStrategy != FreeStrategy.NoFree);
+
+ if (freeStrategy == FreeStrategy.FreeOriginal)
+ {
+ marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
+ }
+
+ if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
+ {
+ marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy);
+ }
}
else
{
marshallingStrategy = new StatelessLinearCollectionSpaceAllocator(marshallerTypeSyntax, nativeType, marshallerData.Shape, numElementsExpression);
+ FreeStrategy freeStrategy = GetFreeStrategy(info, context);
+
IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax);
+ if (freeStrategy == FreeStrategy.FreeOriginal)
+ {
+ marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
+ }
+
IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);
- marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape);
+ marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != FreeStrategy.NoFree);
if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
{
@@ -334,8 +371,15 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax, isLinearCollectionMarshalling: true);
}
- if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
+ if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
+ {
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax);
+ }
+
+ if (freeStrategy == FreeStrategy.FreeOriginal)
+ {
+ marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
+ }
}
IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(
@@ -351,6 +395,48 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
return marshallingGenerator;
}
+ private enum FreeStrategy
+ {
+ ///
+ /// Free the unmanaged value stored in the native identifier.
+ ///
+ FreeNative,
+ ///
+ /// Free the unmanaged value originally passed into the stub.
+ ///
+ FreeOriginal,
+ ///
+ /// Do not free the unmanaged value, we don't own it.
+ ///
+ NoFree
+ }
+
+ private static FreeStrategy GetFreeStrategy(TypePositionInfo info, StubCodeContext context)
+ {
+ // When marshalling from managed to unmanaged, we always own the value in the native identifier.
+ if (context.Direction == MarshalDirection.ManagedToUnmanaged)
+ {
+ return FreeStrategy.FreeNative;
+ }
+
+ // When we're in a case where we don't have state across stages, the parent stub context that can track the state
+ // will only call our Cleanup stage when we own the value in the native identifier.
+ if (!context.AdditionalTemporaryStateLivesAcrossStages)
+ {
+ return FreeStrategy.FreeNative;
+ }
+
+ // In an unmanaged-to-managed stub where a value is passed by 'ref',
+ // we own the original value once we replace it with the new value we're passing out to the caller.
+ if (info.RefKind == RefKind.Ref)
+ {
+ return FreeStrategy.FreeOriginal;
+ }
+
+ // In an unmanaged-to-managed stub, we don't take ownership of the value when it isn't passed by 'ref'.
+ return FreeStrategy.NoFree;
+ }
+
private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource)
{
IElementsMarshalling elementsMarshalling;
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs
index 02d0b6bf70a54..3c1468fbfe58c 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/ElementsMarshalling.cs
@@ -22,9 +22,9 @@ internal interface IElementsMarshallingCollectionSource
internal interface IElementsMarshalling
{
- StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
+ StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context);
- StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
+ StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context);
}
@@ -45,7 +45,7 @@ public BlittableElementsMarshalling(TypeSyntax managedElementType, TypeSyntax un
_collectionSource = collectionSource;
}
- public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
+ public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
@@ -73,7 +73,7 @@ public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeC
Argument(destination)));
}
- public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
+ public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
{
ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context));
@@ -175,7 +175,7 @@ public NonBlittableElementsMarshalling(
_collectionSource = collectionSource;
}
- public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
+ public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
@@ -259,7 +259,7 @@ public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCod
StubCodeContext.Stage.Unmarshal));
}
- public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
+ public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents,
// not the array itself.
@@ -356,7 +356,9 @@ public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, St
VariableDeclarator(
Identifier(nativeSpanIdentifier))
.WithInitializer(EqualsValueClause(
- _collectionSource.GetUnmanagedValuesDestination(info, context)))))),
+ context.Direction == MarshalDirection.ManagedToUnmanaged
+ ? _collectionSource.GetUnmanagedValuesDestination(info, context)
+ : _collectionSource.GetUnmanagedValuesSource(info, context)))))),
contentsCleanupStatements);
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs
index 6bf87839e5846..1790e588a8e79 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/MarshalAsMarshallingGeneratorFactory.cs
@@ -88,7 +88,7 @@ public IMarshallingGenerator Create(
return s_delegate;
case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }:
- if (!context.AdditionalTemporaryStateLivesAcrossStages)
+ if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged)
{
throw new MarshallingNotSupportedException(info, context);
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs
index 8c8aa0db66ccb..2163c71047f87 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatefulMarshallingStrategy.cs
@@ -372,40 +372,36 @@ internal sealed class StatefulLinearCollectionMarshalling : ICustomTypeMarshalli
private readonly MarshallerShape _shape;
private readonly ExpressionSyntax _numElementsExpression;
private readonly IElementsMarshalling _elementsMarshalling;
+ private readonly bool _cleanupElements;
public StatefulLinearCollectionMarshalling(
ICustomTypeMarshallingStrategy innerMarshaller,
MarshallerShape shape,
ExpressionSyntax numElementsExpression,
- IElementsMarshalling elementsMarshalling)
+ IElementsMarshalling elementsMarshalling,
+ bool cleanupElements)
{
_innerMarshaller = innerMarshaller;
_shape = shape;
_numElementsExpression = numElementsExpression;
_elementsMarshalling = elementsMarshalling;
+ _cleanupElements = cleanupElements;
}
public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
+ if (!_cleanupElements)
+ {
+ yield break;
+ }
+
StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);
if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}
-
- if (!_shape.HasFlag(MarshallerShape.Free))
- yield break;
-
- string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
- // .Free();
- yield return ExpressionStatement(
- InvocationExpression(
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName(marshaller),
- IdentifierName(ShapeMemberNames.Free)),
- ArgumentList()));
}
public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
@@ -419,9 +415,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i
yield return statement;
}
- if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
+ if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
{
- yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context);
+ yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context);
yield break;
}
@@ -437,9 +433,10 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);
- if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
+ if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
{
- yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context);
+ yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context);
+ yield break;
}
if (!_shape.HasFlag(MarshallerShape.ToManaged))
@@ -469,4 +466,50 @@ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo
public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
}
+
+ ///
+ /// Marshaller that enables calling the Free method on a stateful marshaller.
+ ///
+ internal sealed class StatefulFreeMarshalling : ICustomTypeMarshallingStrategy
+ {
+ private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
+
+ public StatefulFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller)
+ {
+ _innerMarshaller = innerMarshaller;
+ }
+
+ public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
+
+ public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+ {
+ foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context))
+ {
+ yield return statement;
+ }
+
+ string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
+ // .Free();
+ yield return ExpressionStatement(
+ InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName(marshaller),
+ IdentifierName(ShapeMemberNames.Free)),
+ ArgumentList()));
+ }
+ public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
+
+ public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);
+
+ public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
+ public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
+ public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
+ public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);
+
+ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
+
+ public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
+
+ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
+ }
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs
index 94ea82a31b6d0..fd9459ca79459 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/StatelessMarshallingStrategy.cs
@@ -227,7 +227,7 @@ IEnumerable GenerateCallerAllocatedBufferMarshalStatements()
}
else
{
- // = .ConvertToUnmanaged(, __buffer);
+ // = .ConvertToUnmanaged(, __buffer);
yield return ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
@@ -295,6 +295,134 @@ public IEnumerable GenerateCleanupStatements(TypePositionInfo i
public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
}
+ internal sealed class StatelessUnmanagedToManagedOwnershipTracking : ICustomTypeMarshallingStrategy
+ {
+ internal const string OwnOriginalValueIdentifier = "ownOriginal";
+ internal const string OriginalValueIdentifier = "original";
+
+ private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
+
+ public StatelessUnmanagedToManagedOwnershipTracking(ICustomTypeMarshallingStrategy innerMarshaller)
+ {
+ _innerMarshaller = innerMarshaller;
+ }
+
+ public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
+
+ public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateCleanupStatements(info, context);
+
+ public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
+ public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context)
+ {
+ foreach (StatementSyntax statement in _innerMarshaller.GenerateMarshalStatements(info, context))
+ {
+ yield return statement;
+ }
+
+ // Now that we've set the new value to pass to the caller on the identifier, we need to make sure that we free the old one.
+ // The caller will not see the old one any more, so it won't be able to free it.
+
+ // = true;
+ yield return ExpressionStatement(
+ AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)),
+ LiteralExpression(SyntaxKind.TrueLiteralExpression)));
+ }
+
+ public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
+ public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
+
+ public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
+ public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context)
+ {
+ foreach (StatementSyntax statement in _innerMarshaller.GenerateSetupStatements(info, context))
+ {
+ yield return statement;
+ }
+
+ // bool = false;
+ yield return LocalDeclarationStatement(
+ VariableDeclaration(
+ PredefinedType(Token(SyntaxKind.BoolKeyword)),
+ SingletonSeparatedList(
+ VariableDeclarator(
+ Identifier(context.GetAdditionalIdentifier(info, OwnOriginalValueIdentifier)),
+ null,
+ EqualsValueClause(
+ LiteralExpression(SyntaxKind.FalseLiteralExpression))))));
+
+ // = ;
+ yield return LocalDeclarationStatement(
+ VariableDeclaration(
+ AsNativeType(info).Syntax,
+ SingletonSeparatedList(
+ VariableDeclarator(
+ Identifier(context.GetAdditionalIdentifier(info, OriginalValueIdentifier)),
+ null,
+ EqualsValueClause(
+ IdentifierName(context.GetIdentifiers(info).native))))));
+ }
+
+ public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
+
+ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
+ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
+ }
+
+ internal sealed class FreeOwnedOriginalValueMarshalling : ICustomTypeMarshallingStrategy
+ {
+ private readonly ICustomTypeMarshallingStrategy _innerMarshaller;
+
+ public FreeOwnedOriginalValueMarshalling(ICustomTypeMarshallingStrategy innerMarshaller)
+ {
+ _innerMarshaller = innerMarshaller;
+ }
+
+ public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
+
+ public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
+ {
+ // if ()
+ // {
+ //
+ // }
+ yield return IfStatement(
+ IdentifierName(context.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OwnOriginalValueIdentifier)),
+ Block(_innerMarshaller.GenerateCleanupStatements(info, new OwnedValueCodeContext(context))));
+ }
+
+ public IEnumerable GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);
+ public IEnumerable GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);
+
+ public IEnumerable GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
+ public IEnumerable GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
+
+ public IEnumerable GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
+ public IEnumerable GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);
+
+ public IEnumerable GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);
+
+ public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);
+ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
+
+ private sealed record OwnedValueCodeContext(StubCodeContext InnerContext) : StubCodeContext
+ {
+ public override bool SingleFrameSpansNativeContext => InnerContext.SingleFrameSpansNativeContext;
+
+ public override bool AdditionalTemporaryStateLivesAcrossStages => InnerContext.AdditionalTemporaryStateLivesAcrossStages;
+
+ public override (TargetFramework framework, Version version) GetTargetFramework() => InnerContext.GetTargetFramework();
+
+ public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
+ {
+ var (managed, _) = InnerContext.GetIdentifiers(info);
+ return (managed, InnerContext.GetAdditionalIdentifier(info, StatelessUnmanagedToManagedOwnershipTracking.OriginalValueIdentifier));
+ }
+
+ public override string GetAdditionalIdentifier(TypePositionInfo info, string name) => InnerContext.GetAdditionalIdentifier(info, name);
+ }
+ }
+
///
/// Marshaller type that enables allocating space for marshalling a linear collection using a marshaller that implements the LinearCollection marshalling spec.
///
@@ -404,7 +532,7 @@ public IEnumerable GenerateSetupStatements(TypePositionInfo inf
public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
{
- if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
+ if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
{
yield break;
}
@@ -535,23 +663,31 @@ internal sealed class StatelessLinearCollectionMarshalling : ICustomTypeMarshall
private readonly IElementsMarshalling _elementsMarshalling;
private readonly ManagedTypeInfo _unmanagedType;
private readonly MarshallerShape _shape;
+ private readonly bool _cleanupElementsAndSpace;
public StatelessLinearCollectionMarshalling(
ICustomTypeMarshallingStrategy spaceMarshallingStrategy,
IElementsMarshalling elementsMarshalling,
ManagedTypeInfo unmanagedType,
- MarshallerShape shape)
+ MarshallerShape shape,
+ bool cleanupElementsAndSpace)
{
_spaceMarshallingStrategy = spaceMarshallingStrategy;
_elementsMarshalling = elementsMarshalling;
_unmanagedType = unmanagedType;
_shape = shape;
+ _cleanupElementsAndSpace = cleanupElementsAndSpace;
}
public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _unmanagedType;
public IEnumerable GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
+ if (!_cleanupElementsAndSpace)
+ {
+ yield break;
+ }
+
StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);
if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
@@ -576,9 +712,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i
if (!_shape.HasFlag(MarshallerShape.ToUnmanaged) && !_shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
yield break;
- if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
+ if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
{
- yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context);
+ yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context);
}
else
{
@@ -596,9 +732,9 @@ public IEnumerable GenerateMarshalStatements(TypePositionInfo i
public IEnumerable GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context)
{
- if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
+ if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
{
- yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context);
+ yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context);
yield break;
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs
index 3dbd182406643..7940f6a15fd23 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/NativeToManagedStubCodeContext.cs
@@ -11,7 +11,7 @@ public sealed record NativeToManagedStubCodeContext : StubCodeContext
{
public override bool SingleFrameSpansNativeContext => false;
- public override bool AdditionalTemporaryStateLivesAcrossStages => false;
+ public override bool AdditionalTemporaryStateLivesAcrossStages => true;
private readonly TargetFramework _framework;
private readonly Version _frameworkVersion;
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs
index 937ba3de524ef..68c7420d0b3cb 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IDerivedTests.cs
@@ -48,10 +48,9 @@ public unsafe void CallBaseInterfaceMethod_EnsureQiCalledOnce()
iface.SetInt(5);
Assert.Equal(5, iface.GetInt());
- // https://github.com/dotnet/runtime/issues/85795
- //Assert.Equal("myName", iface.GetName());
- //iface.SetName("updated");
- //Assert.Equal("updated", iface.GetName());
+ Assert.Equal("myName", iface.GetName());
+ iface.SetName("updated");
+ Assert.Equal("updated", iface.GetName());
var iUnknownStrategyProperty = typeof(ComObject).GetProperty("IUnknownStrategy", BindingFlags.NonPublic | BindingFlags.Instance);
@@ -67,7 +66,7 @@ partial class DerivedImpl : IDerived
{
int data = 3;
string myName = "myName";
- public void DoThingWithString([MarshalUsing(typeof(Utf16StringMarshaller))] string name) => throw new NotImplementedException();
+ public void DoThingWithString(string name) => throw new NotImplementedException();
public int GetInt() => data;
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs
index e6214e759450e..20ee9b4a27668 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ImplicitThisTests.cs
@@ -2,10 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
-using ComInterfaceGenerator.Tests;
+using System.Threading;
using Xunit;
namespace ComInterfaceGenerator.Tests
@@ -17,8 +18,7 @@ internal partial class ImplicitThis
[UnmanagedObjectUnwrapperAttribute>]
internal partial interface INativeObject : IUnmanagedInterfaceType
{
-
- private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 2);
+ private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 6);
static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
{
get
@@ -35,6 +35,22 @@ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
int GetData();
[VirtualMethodIndex(1, ImplicitThisParameter = true)]
void SetData(int x);
+ [VirtualMethodIndex(2, ImplicitThisParameter = true)]
+ void ExchangeData(ref int x);
+ [VirtualMethodIndex(3, ImplicitThisParameter = true)]
+ void SumAndSetData(
+ [MarshalUsing(CountElementName = nameof(numValues))] int[] values,
+ int numValues,
+ out int oldValue);
+ [VirtualMethodIndex(4, ImplicitThisParameter = true)]
+ void SumAndSetData(
+ [MarshalUsing(CountElementName = nameof(numValues))] ref int[] values,
+ int numValues,
+ out int oldValue);
+ [VirtualMethodIndex(5, ImplicitThisParameter = true)]
+ void MultiplyWithData(
+ [MarshalUsing(CountElementName = nameof(numValues))] int[] values,
+ int numValues);
}
[NativeMarshalling(typeof(NativeObjectMarshaller))]
@@ -105,16 +121,21 @@ public unsafe void ValidateImplicitThisUnmanagedToManagedFunctionCallsSucceed()
void* wrapper = VTableGCHandlePair.Allocate(impl);
- Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
- NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue);
- Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
- // Verify that we actually updated the managed instance.
- Assert.Equal(newValue, impl.GetData());
-
- VTableGCHandlePair.Free(wrapper);
+ try
+ {
+ Assert.Equal(startingValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
+ NativeExportsNE.ImplicitThis.SetNativeObjectData(wrapper, newValue);
+ Assert.Equal(newValue, NativeExportsNE.ImplicitThis.GetNativeObjectData(wrapper));
+ // Verify that we actually updated the managed instance.
+ Assert.Equal(newValue, impl.GetData());
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
}
- class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject
+ sealed class ManagedObjectImplementation : NativeExportsNE.ImplicitThis.INativeObject
{
private int _data;
@@ -123,8 +144,24 @@ public ManagedObjectImplementation(int value)
_data = value;
}
+ public void ExchangeData(ref int x) => x = Interlocked.Exchange(ref _data, x);
public int GetData() => _data;
+ public void MultiplyWithData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues)
+ {
+ for (int i = 0; i < values.Length; i++)
+ {
+ values[i] *= _data;
+ }
+ }
public void SetData(int x) => _data = x;
+ public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] int[] values, int numValues, out int oldValue)
+ {
+ int value = values.Sum();
+ oldValue = _data;
+ _data = value;
+ }
+
+ public void SumAndSetData([MarshalUsing(CountElementName = "numValues")] ref int[] values, int numValues, out int oldValue) => SumAndSetData(values, numValues, out oldValue);
}
}
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs
new file mode 100644
index 0000000000000..224ea68c41477
--- /dev/null
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs
@@ -0,0 +1,471 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using SharedTypes;
+using Xunit;
+using static ComInterfaceGenerator.Tests.UnmanagedToManagedCustomMarshallingTests;
+
+namespace ComInterfaceGenerator.Tests
+{
+ internal unsafe partial class NativeExportsNE
+ {
+ internal partial class UnmanagedToManagedCustomMarshalling
+ {
+ [UnmanagedObjectUnwrapper>]
+ internal partial interface INativeObject : IUnmanagedInterfaceType
+ {
+ private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObject), sizeof(void*) * 6);
+ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
+ {
+ get
+ {
+ if (s_vtable[0] == null)
+ {
+ Native.PopulateUnmanagedVirtualMethodTable(s_vtable);
+ }
+ return s_vtable;
+ }
+ }
+
+ [VirtualMethodIndex(0, ImplicitThisParameter = true)]
+ [return: MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))]
+ IntWrapper GetData();
+ [VirtualMethodIndex(1, ImplicitThisParameter = true)]
+ void SetData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] IntWrapper x);
+ [VirtualMethodIndex(2, ImplicitThisParameter = true)]
+ void ExchangeData([MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] ref IntWrapper data);
+ [VirtualMethodIndex(3, ImplicitThisParameter = true)]
+ void SumAndSetData(
+ [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values,
+ int numValues,
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue);
+ [VirtualMethodIndex(4, ImplicitThisParameter = true)]
+ void SumAndSetData(
+ [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values,
+ int numValues,
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue);
+ [VirtualMethodIndex(5, ImplicitThisParameter = true)]
+ void MultiplyWithData(
+ [MarshalUsing(CountElementName = nameof(numValues)), MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1), In, Out] IntWrapper[] values123,
+ int numValues);
+ }
+
+ [UnmanagedObjectUnwrapper>]
+ internal partial interface INativeObjectStateful : IUnmanagedInterfaceType
+ {
+ private static void** s_vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(INativeObjectStateful), sizeof(void*) * 6);
+ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
+ {
+ get
+ {
+ if (s_vtable[0] == null)
+ {
+ Native.PopulateUnmanagedVirtualMethodTable(s_vtable);
+ }
+ return s_vtable;
+ }
+ }
+
+ [VirtualMethodIndex(3, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)]
+ void SumAndSetData(
+ [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))]
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] IntWrapper[] values,
+ int numValues,
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue);
+ [VirtualMethodIndex(4, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)]
+ void SumAndSetData(
+ [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))]
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1)] ref IntWrapper[] values,
+ int numValues,
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts))] out IntWrapper oldValue);
+
+ [VirtualMethodIndex(5, ImplicitThisParameter = true, Direction = MarshalDirection.UnmanagedToManaged)]
+ void MultiplyWithData(
+ [MarshalUsing(typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>), CountElementName = nameof(numValues))]
+ [MarshalUsing(typeof(IntWrapperMarshallerToIntWithFreeCounts), ElementIndirectionDepth = 1), In, Out] IntWrapper[] values123,
+ int numValues);
+ }
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_native_object_data")]
+ public static partial void SetNativeObjectData(void* obj, int data);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_native_object_data")]
+ public static partial int GetNativeObjectData(void* obj);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "exchange_native_object_data")]
+ public static partial int ExchangeNativeObjectData(void* obj, ref int x);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data")]
+ public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues, out int oldValue);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_and_set_native_object_data_with_ref")]
+ public static partial int SumAndSetNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] ref int[] arr, int numValues, out int oldValue);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "multiply_with_native_object_data")]
+ public static partial int MultiplyWithNativeObjectData(void* obj, [MarshalUsing(CountElementName = nameof(numValues))] int[] arr, int numValues);
+ }
+ }
+ public class UnmanagedToManagedCustomMarshallingTests
+ {
+ [Fact]
+ public unsafe void ValidateOnlyByRefFreed_Stateless()
+ {
+ const int startingValue = 13;
+ const int newValue = 42;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.GetNativeObjectData(wrapper);
+
+ Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.SetNativeObjectData(wrapper, newValue);
+ Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+
+ int finalValue = 10;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.ExchangeNativeObjectData(wrapper, ref finalValue);
+ Assert.Equal(freeCalls + 1, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ [Fact]
+ public unsafe void ValidateArrayElementsAndOutParameterNotFreed_Stateless()
+ {
+ const int startingValue = 13;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ var values = new int[] { 1, 32, 63, 124, 255 };
+
+ int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, values, values.Length, out int _);
+
+ Assert.Equal(freeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ [Fact]
+ public unsafe void ValidateArrayElementsByRefFreed_Stateless()
+ {
+ const int startingValue = 13;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ var values = new int[] { 1, 32, 63, 124, 255 };
+
+ int freeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, ref values, values.Length, out int _);
+
+ Assert.Equal(freeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ [Fact]
+ [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")]
+ public unsafe void ValidateArrayElementsByValueOutFreed_Stateless()
+ {
+ const int startingValue = 13;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ var values = new int[] { 1, 32, 63, 124, 255 };
+ var expected = values.Select(x => x * startingValue).ToArray();
+
+ int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length);
+
+ Assert.Equal(expected, values);
+
+ Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ [Fact]
+ public unsafe void ValidateArrayElementsAndOutParameterNotFreed_Stateful()
+ {
+ const int startingValue = 13;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ var values = new int[] { 1, 32, 63, 124, 255 };
+
+ int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+ int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, values, values.Length, out int _);
+
+ // We shouldn't free the elements, but we always free the stateful marshaller.
+ Assert.Equal(elementFreeCalls, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ [Fact]
+ public unsafe void ValidateArrayElementsByRefFreed_Stateful()
+ {
+ const int startingValue = 13;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ var values = new int[] { 1, 32, 63, 124, 255 };
+
+ int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+ int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.Ref.NumCallsToFree;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.SumAndSetNativeObjectData(wrapper, ref values, values.Length, out int _);
+
+ Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.Ref.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ [Fact]
+ [ActiveIssue("https://github.com/dotnet/runtime/issues/86608")]
+ public unsafe void ValidateArrayElementsByValueOutFreed_Stateful()
+ {
+ const int startingValue = 13;
+
+ ManagedObjectImplementation impl = new ManagedObjectImplementation(startingValue);
+
+ void* wrapper = VTableGCHandlePair.Allocate(impl);
+
+ try
+ {
+ var values = new int[] { 1, 32, 63, 124, 255 };
+ var expected = values.Select(x => x * startingValue).ToArray();
+
+ int elementFreeCalls = IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree;
+ int marshallerFreeCalls = StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree;
+
+ NativeExportsNE.UnmanagedToManagedCustomMarshalling.MultiplyWithNativeObjectData(wrapper, values, values.Length);
+
+ Assert.Equal(expected, values);
+
+ Assert.Equal(elementFreeCalls + values.Length, IntWrapperMarshallerToIntWithFreeCounts.NumCallsToFree);
+ Assert.Equal(marshallerFreeCalls + 1, StatefulUnmanagedToManagedCollectionMarshaller.In.NumCallsToFree);
+ }
+ finally
+ {
+ VTableGCHandlePair.Free(wrapper);
+ }
+ }
+
+ sealed unsafe class ManagedObjectImplementation : NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObject, NativeExportsNE.UnmanagedToManagedCustomMarshalling.INativeObjectStateful
+ {
+ private IntWrapper _data;
+
+ public ManagedObjectImplementation(int value)
+ {
+ _data = new() { i = value };
+ }
+
+ public void ExchangeData(ref IntWrapper x) => x = Interlocked.Exchange(ref _data, x);
+ public IntWrapper GetData() => _data;
+ public void MultiplyWithData(IntWrapper[] values, int numValues)
+ {
+ for (int i = 0; i < values.Length; i++)
+ {
+ values[i].i *= _data.i;
+ }
+ }
+ public void SetData(IntWrapper x) => _data = x;
+ public void SumAndSetData(ref IntWrapper[] values, int numValues, out IntWrapper oldValue) => SumAndSetData(values, numValues, out oldValue);
+ public void SumAndSetData(IntWrapper[] values, int numValues, out IntWrapper oldValue)
+ {
+ int value = values.Sum(value => value.i);
+ oldValue = _data;
+ _data = new() { i = value };
+ }
+
+ static void* IUnmanagedInterfaceType.VirtualMethodTableManagedImplementation
+ {
+ get
+ {
+ Assert.Fail("The VirtualMethodTableManagedImplementation property should not be called on implementing class types");
+ return null;
+ }
+ }
+ }
+
+ [CustomMarshaller(typeof(IntWrapper), MarshalMode.Default, typeof(IntWrapperMarshallerToIntWithFreeCounts))]
+ public static unsafe class IntWrapperMarshallerToIntWithFreeCounts
+ {
+ [ThreadStatic]
+ public static int NumCallsToFree = 0;
+
+ public static int ConvertToUnmanaged(IntWrapper managed)
+ {
+ return managed.i;
+ }
+
+ public static IntWrapper ConvertToManaged(int unmanaged)
+ {
+ return new IntWrapper { i = unmanaged };
+ }
+
+ public static void Free(int _)
+ {
+ NumCallsToFree++;
+ }
+ }
+
+ [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), MarshalMode.UnmanagedToManagedIn, typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>.In))]
+ [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder[]), MarshalMode.UnmanagedToManagedRef, typeof(StatefulUnmanagedToManagedCollectionMarshaller<,>.Ref))]
+ [ContiguousCollectionMarshaller]
+ public unsafe static class StatefulUnmanagedToManagedCollectionMarshaller
+ where TUnmanaged : unmanaged
+ {
+ public struct In
+ {
+ [ThreadStatic]
+ public static int NumCallsToFree = 0;
+
+ private TUnmanaged* _unmanaged;
+ private TManaged[] _managed;
+
+ public void FromUnmanaged(TUnmanaged* unmanaged)
+ {
+ _unmanaged = unmanaged;
+ }
+
+ public Span GetManagedValuesDestination(int numElements)
+ {
+ return _managed = new TManaged[numElements];
+ }
+
+ public ReadOnlySpan GetUnmanagedValuesSource(int numElements)
+ {
+ return new(_unmanaged, numElements);
+ }
+
+ public TManaged[] ToManaged()
+ {
+ return _managed;
+ }
+
+ public void Free()
+ {
+ NumCallsToFree++;
+ }
+ }
+
+ public struct Ref
+ {
+ [ThreadStatic]
+ public static int NumCallsToFree = 0;
+
+ private TUnmanaged* _originalUnmanaged;
+ private TUnmanaged* _unmanaged;
+ private TManaged[] _managed;
+
+ public void FromUnmanaged(TUnmanaged* unmanaged)
+ {
+ _originalUnmanaged = unmanaged;
+ }
+
+ public Span GetManagedValuesDestination(int numElements)
+ {
+ return _managed = new TManaged[numElements];
+ }
+
+ public ReadOnlySpan GetUnmanagedValuesSource(int numElements)
+ {
+ return new(_originalUnmanaged, numElements);
+ }
+
+ public TManaged[] ToManaged()
+ {
+ return _managed;
+ }
+
+ public void Free()
+ {
+ Marshal.FreeCoTaskMem((nint)_originalUnmanaged);
+ NumCallsToFree++;
+ }
+
+ public void FromManaged(TManaged[] managed)
+ {
+ _managed = managed;
+ }
+
+ public TUnmanaged* ToUnmanaged()
+ {
+ return _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length);
+ }
+
+ public ReadOnlySpan GetManagedValuesSource()
+ {
+ return _managed;
+ }
+
+ public Span GetUnmanagedValuesDestination()
+ {
+ return new(_unmanaged, _managed.Length);
+ }
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs
index 89c1e84fd5ab8..628db42618719 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs
@@ -327,7 +327,6 @@ public static IEnumerable