Skip to content

Commit

Permalink
Ensure proper receiver value is used for a constrained call invocation (
Browse files Browse the repository at this point in the history
#65642)

Related to #63221.
  • Loading branch information
AlekseyTs authored Dec 6, 2022
1 parent 1c0f249 commit 19704e6
Show file tree
Hide file tree
Showing 32 changed files with 10,254 additions and 7,523 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ public static RefKind GetRefKind(this BoundExpression node)
case BoundKind.PropertyAccess:
return ((BoundPropertyAccess)node).PropertySymbol.RefKind;

case BoundKind.IndexerAccess:
return ((BoundIndexerAccess)node).Indexer.RefKind;

case BoundKind.ImplicitIndexerAccess:
return ((BoundImplicitIndexerAccess)node).IndexerOrSliceAccess.GetRefKind();

case BoundKind.ObjectInitializerMember:
var member = (BoundObjectInitializerMember)node;
if (member.HasErrors)
Expand Down
4 changes: 2 additions & 2 deletions src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1716,10 +1716,10 @@
<Field Name="Id" Type="int"/>
</Node>

<!-- This node represents a complex receiver for a conditional access.
<!-- This node represents a complex receiver for a call, or a conditional access.
At runtime, when its type is a value type, ValueTypeReceiver should be used as a receiver.
Otherwise, ReferenceTypeReceiver should be used.
This kind of receiver is created only by Async rewriter.
This kind of receiver is created only by SpillSequenceSpiller rewriter.
-->
<Node Name="BoundComplexConditionalReceiver" Base="BoundExpression">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
Expand Down
146 changes: 145 additions & 1 deletion src/Compilers/CSharp/Portable/CodeGen/EmitExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,7 @@ private void EmitInstanceCallExpression(BoundCall call, UseKind useKind)
var receiver = call.ReceiverOpt;
var arguments = call.Arguments;
LocalDefinition tempOpt = null;
LocalDefinition cloneTemp = null;

Debug.Assert(!method.IsStatic && method.RequiresInstanceReceiver);

Expand Down Expand Up @@ -1655,6 +1656,51 @@ private void EmitInstanceCallExpression(BoundCall call, UseKind useKind)
CallKind.ConstrainedCallVirt;

tempOpt = EmitReceiverRef(receiver, callKind == CallKind.ConstrainedCallVirt ? AddressKind.Constrained : AddressKind.Writeable);

if (callKind == CallKind.ConstrainedCallVirt && tempOpt is null && !receiverType.IsValueType &&
!ReceiverIsKnownToReferToTempIfReferenceType(call.ReceiverOpt) &&
!IsSafeToDereferenceReceiverRefAfterEvaluatingArguments(call.Arguments))
{
// A case where T is actually a class must be handled specially.
// Taking a reference to a class instance is fragile because the value behind the
// reference might change while arguments are evaluated. However, the call should be
// performed on the instance that is behind reference at the time we push the
// reference to the stack. So, for a class we need to emit a reference to a temporary
// location, rather than to the original location

// Struct values are never nulls.
// We will emit a check for such case, but the check is really a JIT-time
// constant since JIT will know if T is a struct or not.

// if ((object)default(T) == null)
// {
// temp = receiverRef
// receiverRef = ref temp
// }

object whenNotNullLabel = null;

if (!receiverType.IsReferenceType)
{
// if ((object)default(T) == null)
EmitDefaultValue(receiverType, true, receiver.Syntax);
EmitBox(receiverType, receiver.Syntax);
whenNotNullLabel = new object();
_builder.EmitBranch(ILOpCode.Brtrue, whenNotNullLabel);
}

// temp = receiverRef
// receiverRef = ref temp
EmitLoadIndirect(receiverType, receiver.Syntax);
cloneTemp = AllocateTemp(receiverType, receiver.Syntax);
_builder.EmitLocalStore(cloneTemp);
_builder.EmitLocalAddress(cloneTemp);

if (whenNotNullLabel is not null)
{
_builder.MarkLabel(whenNotNullLabel);
}
}
}

// When emitting a callvirt to a virtual method we always emit the method info of the
Expand Down Expand Up @@ -1730,6 +1776,104 @@ private void EmitInstanceCallExpression(BoundCall call, UseKind useKind)
EmitCallCleanup(call.Syntax, useKind, method);

FreeOptTemp(tempOpt);
FreeOptTemp(cloneTemp);
}

internal static bool IsPossibleReferenceTypeReceiverOfConstrainedCall(BoundExpression receiver)
{
var receiverType = receiver.Type;

if (receiverType.IsVerifierReference() || receiverType.IsVerifierValue())
{
return false;
}

return !receiverType.IsValueType;
}

internal static bool ReceiverIsKnownToReferToTempIfReferenceType(BoundExpression receiver)
{
while (receiver is BoundSequence sequence)
{
receiver = sequence.Value;
}

if (receiver is
BoundLocal { LocalSymbol.IsKnownToReferToTempIfReferenceType: true } or
BoundComplexConditionalReceiver)
{
return true;
}

return false;
}

internal static bool IsSafeToDereferenceReceiverRefAfterEvaluatingArguments(ImmutableArray<BoundExpression> arguments)
{
return arguments.All(isSafeToDereferenceReceiverRefAfterEvaluatingArgument);

static bool isSafeToDereferenceReceiverRefAfterEvaluatingArgument(BoundExpression expression)
{
var current = expression;
while (true)
{
if (current.ConstantValue != null)
{
return true;
}

switch (current.Kind)
{
default:
return false;
case BoundKind.TypeExpression:
case BoundKind.Parameter:
case BoundKind.Local:
case BoundKind.ThisReference:
return true;
case BoundKind.FieldAccess:
{
var field = (BoundFieldAccess)current;
current = field.ReceiverOpt;
if (current is null)
{
return true;
}

break;
}
case BoundKind.PassByCopy:
current = ((BoundPassByCopy)current).Expression;
break;
case BoundKind.BinaryOperator:
{
BoundBinaryOperator b = (BoundBinaryOperator)current;
Debug.Assert(!b.OperatorKind.IsUserDefined());

if (b.OperatorKind.IsUserDefined() || !isSafeToDereferenceReceiverRefAfterEvaluatingArgument(b.Right))
{
return false;
}

current = b.Left;
break;
}
case BoundKind.Conversion:
{
BoundConversion conv = (BoundConversion)current;
Debug.Assert(!conv.ConversionKind.IsUserDefinedConversion());

if (conv.ConversionKind.IsUserDefinedConversion())
{
return false;
}

current = conv.Operand;
break;
}
}
}
}
}

private bool IsReadOnlyCall(MethodSymbol method, NamedTypeSymbol methodContainingType)
Expand Down Expand Up @@ -1759,7 +1903,7 @@ private bool IsReadOnlyCall(MethodSymbol method, NamedTypeSymbol methodContainin
// returns true when receiver is already a ref.
// in such cases calling through a ref could be preferred over
// calling through indirectly loaded value.
private bool IsRef(BoundExpression receiver)
internal static bool IsRef(BoundExpression receiver)
{
switch (receiver.Kind)
{
Expand Down
30 changes: 27 additions & 3 deletions src/Compilers/CSharp/Portable/CodeGen/Optimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ public override BoundNode VisitCall(BoundCall node)
// assume we will need an address (that will prevent scheduling of receiver).
if (method.RequiresInstanceReceiver)
{
receiver = VisitCallReceiver(receiver);
receiver = VisitCallOrConditionalAccessReceiver(receiver, node);
}
else
{
Expand All @@ -1132,9 +1132,28 @@ public override BoundNode VisitCall(BoundCall node)
return node.Update(receiver, method, rewrittenArguments);
}

private BoundExpression VisitCallReceiver(BoundExpression receiver)
private BoundExpression VisitCallOrConditionalAccessReceiver(BoundExpression receiver, BoundCall callOpt)
{
var receiverType = receiver.Type;

if (callOpt is { } call &&
CodeGenerator.IsRef(receiver) &&
CodeGenerator.IsPossibleReferenceTypeReceiverOfConstrainedCall(receiver) &&
!CodeGenerator.IsSafeToDereferenceReceiverRefAfterEvaluatingArguments(call.Arguments))
{
var unwrappedSequence = receiver;

while (unwrappedSequence is BoundSequence sequence)
{
unwrappedSequence = sequence.Value;
}

if (unwrappedSequence is BoundLocal { LocalSymbol: { RefKind: not RefKind.None } localSymbol })
{
ShouldNotSchedule(localSymbol); // Otherwise CodeGenerator is unable to apply proper fixups
}
}

ExprContext context;

if (receiverType.IsReferenceType)
Expand Down Expand Up @@ -1494,7 +1513,7 @@ public override BoundNode VisitNullCoalescingOperator(BoundNullCoalescingOperato
public override BoundNode VisitLoweredConditionalAccess(BoundLoweredConditionalAccess node)
{
var origStack = StackDepth();
BoundExpression receiver = VisitCallReceiver(node.Receiver);
BoundExpression receiver = VisitCallOrConditionalAccessReceiver(node.Receiver, callOpt: null);

var cookie = GetStackStateCookie(); // implicit branch here

Expand Down Expand Up @@ -2210,6 +2229,11 @@ internal override bool IsPinned
get { return false; }
}

internal override bool IsKnownToReferToTempIfReferenceType
{
get { return false; }
}

public override Symbol ContainingSymbol
{
get { throw new NotImplementedException(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public override bool Equals(Symbol obj, TypeCompareKind compareKind)
internal override bool IsCompilerGenerated => true;
internal override bool IsImportedFromMetadata => false;
internal override bool IsPinned => false;
internal override bool IsKnownToReferToTempIfReferenceType => false;
public override RefKind RefKind => RefKind.None;
internal override SynthesizedLocalKind SynthesizedKind => throw ExceptionUtilities.Unreachable();
internal override ConstantValue GetConstantValue(SyntaxNode node, LocalSymbol inProgress, BindingDiagnosticBag diagnostics = null) => null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ private BoundExpression MakePropertyAssignment(
ArrayBuilder<LocalSymbol>? argTempsBuilder = null;
arguments = VisitArgumentsAndCaptureReceiverIfNeeded(
ref rewrittenReceiver,
captureReceiverForMultipleInvocations: false,
captureReceiverMode: ReceiverCaptureMode.Default,
arguments,
property,
argsToParamsOpt,
Expand Down
Loading

0 comments on commit 19704e6

Please sign in to comment.