Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only call Free in unmanaged->managed stubs when ownership has been transfered to the callee #86415

Merged
merged 9 commits into from
May 23, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public LinearCollectionElementMarshallingCodeContext(
_managedSpanIdentifier = managedSpanIdentifier;
_nativeSpanIdentifier = nativeSpanIdentifier;
ParentContext = parentContext;
Direction = ParentContext.Direction;
}

public override (TargetFramework framework, Version version) GetTargetFramework()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,22 @@ private IMarshallingGenerator CreateCustomNativeTypeMarshaller(TypePositionInfo
if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax, marshallerData.BufferElementType.Syntax, isLinearCollectionMarshalling: false);

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
FreeStrategy freeStrategy = GetFreeStrategy(info, context);

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
}

if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
{
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerData.MarshallerType.Syntax);
}

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
}
}

IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
Expand Down Expand Up @@ -311,19 +325,42 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new StatefulCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax);
}

FreeStrategy freeStrategy = GetFreeStrategy(info, context);
IElementsMarshallingCollectionSource collectionSource = new StatefulLinearCollectionSource();
IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);

marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling);
if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
}

marshallingStrategy = new StatefulLinearCollectionMarshalling(marshallingStrategy, marshallerData.Shape, numElementsExpression, elementsMarshalling, freeStrategy != FreeStrategy.NoFree);

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
}

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
{
marshallingStrategy = new StatefulFreeMarshalling(marshallingStrategy);
}
}
else
{
marshallingStrategy = new StatelessLinearCollectionSpaceAllocator(marshallerTypeSyntax, nativeType, marshallerData.Shape, numElementsExpression);

FreeStrategy freeStrategy = GetFreeStrategy(info, context);

IElementsMarshallingCollectionSource collectionSource = new StatelessLinearCollectionSource(marshallerTypeSyntax);
if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new StatelessUnmanagedToManagedOwnershipTracking(marshallingStrategy);
}

IElementsMarshalling elementsMarshalling = CreateElementsMarshalling(marshallerData, elementInfo, elementMarshaller, unmanagedElementType, collectionSource);

marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape);
marshallingStrategy = new StatelessLinearCollectionMarshalling(marshallingStrategy, elementsMarshalling, nativeType, marshallerData.Shape, freeStrategy != FreeStrategy.NoFree);

if (marshallerData.Shape.HasFlag(MarshallerShape.CallerAllocatedBuffer))
{
Expand All @@ -334,8 +371,15 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = new StatelessCallerAllocatedBufferMarshalling(marshallingStrategy, marshallerTypeSyntax, bufferElementTypeSyntax, isLinearCollectionMarshalling: true);
}

if (marshallerData.Shape.HasFlag(MarshallerShape.Free))
if (freeStrategy != FreeStrategy.NoFree && marshallerData.Shape.HasFlag(MarshallerShape.Free))
{
marshallingStrategy = new StatelessFreeMarshalling(marshallingStrategy, marshallerTypeSyntax);
}

if (freeStrategy == FreeStrategy.FreeOriginal)
{
marshallingStrategy = new FreeOwnedOriginalValueMarshalling(marshallingStrategy);
}
}

IMarshallingGenerator marshallingGenerator = new CustomTypeMarshallingGenerator(
Expand All @@ -351,6 +395,48 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
return marshallingGenerator;
}

private enum FreeStrategy
{
/// <summary>
/// Free the unmanaged value stored in the native identifier.
/// </summary>
FreeNative,
/// <summary>
/// Free the unmanaged value originally passed into the stub.
/// </summary>
FreeOriginal,
/// <summary>
/// Do not free the unmanaged value, we don't own it.
/// </summary>
NoFree
}

private static FreeStrategy GetFreeStrategy(TypePositionInfo info, StubCodeContext context)
{
// When marshalling from managed to unmanaged, we always own the value in the native identifier.
if (context.Direction == MarshalDirection.ManagedToUnmanaged)
{
return FreeStrategy.FreeNative;
}

// When we're in a case where we don't have state across stages, the parent stub context that can track the state
// will only call our Cleanup stage when we own the value in the native identifier.
if (!context.AdditionalTemporaryStateLivesAcrossStages)
{
return FreeStrategy.FreeNative;
}

// In an unmanaged-to-managed stub where a value is passed by 'ref',
// we own the original value once we replace it with the new value we're passing out to the caller.
if (info.RefKind == RefKind.Ref)
{
return FreeStrategy.FreeOriginal;
}

// In an unmanaged-to-managed stub, we don't take ownership of the value when it isn't passed by 'ref'.
return FreeStrategy.NoFree;
}

private static IElementsMarshalling CreateElementsMarshalling(CustomTypeMarshallerData marshallerData, TypePositionInfo elementInfo, IMarshallingGenerator elementMarshaller, TypeSyntax unmanagedElementType, IElementsMarshallingCollectionSource collectionSource)
{
IElementsMarshalling elementsMarshalling;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ internal interface IElementsMarshallingCollectionSource

internal interface IElementsMarshalling
{
StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCodeContext context);
StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, StubCodeContext context);
}
Expand All @@ -45,7 +45,7 @@ public BlittableElementsMarshalling(TypeSyntax managedElementType, TypeSyntax un
_collectionSource = collectionSource;
}

public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
Expand Down Expand Up @@ -73,7 +73,7 @@ public StatementSyntax GenerateMarshalStatement(TypePositionInfo info, StubCodeC
Argument(destination)));
}

public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
{
ExpressionSyntax source = CastToManagedIfNecessary(_collectionSource.GetUnmanagedValuesDestination(info, context));

Expand Down Expand Up @@ -175,7 +175,7 @@ public NonBlittableElementsMarshalling(
_collectionSource = collectionSource;
}

public StatementSyntax GenerateByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutMarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
Expand Down Expand Up @@ -259,7 +259,7 @@ public StatementSyntax GenerateUnmarshalStatement(TypePositionInfo info, StubCod
StubCodeContext.Stage.Unmarshal));
}

public StatementSyntax GenerateByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
public StatementSyntax GenerateManagedToUnmanagedByValueOutUnmarshalStatement(TypePositionInfo info, StubCodeContext context)
{
// Use ManagedSource and NativeDestination spans for by-value marshalling since we're just marshalling back the contents,
// not the array itself.
Expand Down Expand Up @@ -356,7 +356,9 @@ public StatementSyntax GenerateElementCleanupStatement(TypePositionInfo info, St
VariableDeclarator(
Identifier(nativeSpanIdentifier))
.WithInitializer(EqualsValueClause(
_collectionSource.GetUnmanagedValuesDestination(info, context)))))),
context.Direction == MarshalDirection.ManagedToUnmanaged
? _collectionSource.GetUnmanagedValuesDestination(info, context)
: _collectionSource.GetUnmanagedValuesSource(info, context)))))),
contentsCleanupStatements);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public IMarshallingGenerator Create(
return s_delegate;

case { MarshallingAttributeInfo: SafeHandleMarshallingInfo(_, bool isAbstract) }:
if (!context.AdditionalTemporaryStateLivesAcrossStages)
if (!context.AdditionalTemporaryStateLivesAcrossStages || context.Direction != MarshalDirection.ManagedToUnmanaged)
{
throw new MarshallingNotSupportedException(info, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,40 +372,36 @@ internal sealed class StatefulLinearCollectionMarshalling : ICustomTypeMarshalli
private readonly MarshallerShape _shape;
private readonly ExpressionSyntax _numElementsExpression;
private readonly IElementsMarshalling _elementsMarshalling;
private readonly bool _cleanupElements;

public StatefulLinearCollectionMarshalling(
ICustomTypeMarshallingStrategy innerMarshaller,
MarshallerShape shape,
ExpressionSyntax numElementsExpression,
IElementsMarshalling elementsMarshalling)
IElementsMarshalling elementsMarshalling,
bool cleanupElements)
{
_innerMarshaller = innerMarshaller;
_shape = shape;
_numElementsExpression = numElementsExpression;
_elementsMarshalling = elementsMarshalling;
_cleanupElements = cleanupElements;
}

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);
public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
if (!_cleanupElements)
{
yield break;
}

StatementSyntax elementCleanup = _elementsMarshalling.GenerateElementCleanupStatement(info, context);

if (!elementCleanup.IsKind(SyntaxKind.EmptyStatement))
{
yield return elementCleanup;
}

if (!_shape.HasFlag(MarshallerShape.Free))
yield break;

string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
// <marshaller>.Free();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(marshaller),
IdentifierName(ShapeMemberNames.Free)),
ArgumentList()));
}
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);

Expand All @@ -419,9 +415,9 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
yield return statement;
}

if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
{
yield return _elementsMarshalling.GenerateByValueOutMarshalStatement(info, context);
yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutMarshalStatement(info, context);
yield break;
}

Expand All @@ -437,9 +433,10 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo
{
string numElementsIdentifier = MarshallerHelpers.GetNumElementsIdentifier(info, context);

if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
if (context.Direction == MarshalDirection.ManagedToUnmanaged && !info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
{
yield return _elementsMarshalling.GenerateByValueOutUnmarshalStatement(info, context);
yield return _elementsMarshalling.GenerateManagedToUnmanagedByValueOutUnmarshalStatement(info, context);
yield break;
}

if (!_shape.HasFlag(MarshallerShape.ToManaged))
Expand Down Expand Up @@ -469,4 +466,50 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => true;
}

/// <summary>
/// Marshaller that enables calling the Free method on a stateful marshaller.
/// </summary>
internal sealed class StatefulFreeMarshalling : ICustomTypeMarshallingStrategy
{
private readonly ICustomTypeMarshallingStrategy _innerMarshaller;

public StatefulFreeMarshalling(ICustomTypeMarshallingStrategy innerMarshaller)
{
_innerMarshaller = innerMarshaller;
}

public ManagedTypeInfo AsNativeType(TypePositionInfo info) => _innerMarshaller.AsNativeType(info);

public IEnumerable<StatementSyntax> GenerateCleanupStatements(TypePositionInfo info, StubCodeContext context)
{
foreach (var statement in _innerMarshaller.GenerateCleanupStatements(info, context))
{
yield return statement;
}

string marshaller = StatefulValueMarshalling.GetMarshallerIdentifier(info, context);
// <marshaller>.Free();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(marshaller),
IdentifierName(ShapeMemberNames.Free)),
ArgumentList()));
}
public IEnumerable<StatementSyntax> GenerateGuaranteedUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateGuaranteedUnmarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateMarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateNotifyForSuccessfulInvokeStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateNotifyForSuccessfulInvokeStatements(info, context);
public IEnumerable<StatementSyntax> GeneratePinnedMarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinnedMarshalStatements(info, context);
public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GeneratePinStatements(info, context);
public IEnumerable<StatementSyntax> GenerateSetupStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateSetupStatements(info, context);

public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalStatements(info, context);

public IEnumerable<StatementSyntax> GenerateUnmarshalCaptureStatements(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.GenerateUnmarshalCaptureStatements(info, context);

public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context) => _innerMarshaller.UsesNativeIdentifier(info, context);
}
}
Loading