Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use GetResult(out Exception) method for awaits #67487

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions azure-pipelines-official.yml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ stages:
-pack
-sign
-publish
-bootstrap
-bootstrapConfiguration Release
-binaryLog
-configuration $(BuildConfiguration)
-officialBuildId $(Build.BuildNumber)
Expand Down
125 changes: 111 additions & 14 deletions src/Compilers/CSharp/Portable/Binder/Binder_Await.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ internal BoundAwaitableInfo BindAwaitInfo(BoundAwaitableValuePlaceholder placeho
out MethodSymbol? getResult,
getAwaiterGetResultCall: out _,
node,
diagnostics);
diagnostics,
useGetResultException: true);
hasErrors |= hasGetAwaitableErrors;

return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true };
Expand Down Expand Up @@ -123,8 +124,7 @@ private bool CouldBeAwaited(BoundExpression expression)
return false;
}

return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _,
node: syntax, diagnostics: BindingDiagnosticBag.Discarded);
return GetAwaitableExpressionInfo(expression, node: syntax, diagnostics: BindingDiagnosticBag.Discarded);
}

/// <summary>
Expand Down Expand Up @@ -235,6 +235,19 @@ private bool ReportBadAwaitContext(SyntaxNodeOrToken nodeOrToken, BindingDiagnos
}
}

/// <summary>
/// Finds and validates the required members of an awaitable expression, as described in spec 7.7.7.1.
/// </summary>
/// <returns>True if the expression is awaitable; false otherwise.</returns>
internal bool GetAwaitableExpressionInfo(
BoundExpression expression,
SyntaxNode node,
BindingDiagnosticBag diagnostics)
{
// PROTOTYPE test all scenarios that rely on this method
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out _, node, diagnostics, useGetResultException: true);
}

/// <summary>
/// Finds and validates the required members of an awaitable expression, as described in spec 7.7.7.1.
/// </summary>
Expand All @@ -245,7 +258,10 @@ internal bool GetAwaitableExpressionInfo(
SyntaxNode node,
BindingDiagnosticBag diagnostics)
{
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, node, diagnostics);
// PROTOTYPE test all scenarios that rely on this method
// Callers (such as logic for async Main binding) who want to use the bound node for GetResult call
// should not use the `GetResult(out Exception)` overload.
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, node, diagnostics, useGetResultException: false);
}

private bool GetAwaitableExpressionInfo(
Expand All @@ -257,7 +273,8 @@ private bool GetAwaitableExpressionInfo(
out MethodSymbol? getResult,
out BoundExpression? getAwaiterGetResultCall,
SyntaxNode node,
BindingDiagnosticBag diagnostics)
BindingDiagnosticBag diagnostics,
bool useGetResultException)
{
Debug.Assert(TypeSymbol.Equals(expression.Type, getAwaiterArgument.Type, TypeCompareKind.ConsiderEverything));

Expand All @@ -272,6 +289,7 @@ private bool GetAwaitableExpressionInfo(
return false;
}

// PROTOTYPE deal with dynamic scenario
if (expression.HasDynamicType())
{
isDynamic = true;
Expand All @@ -284,9 +302,11 @@ private bool GetAwaitableExpressionInfo(
}

TypeSymbol awaiterType = getAwaiter.Type!;
useGetResultException = useGetResultException && !Flags.Includes(BinderFlags.InTryBlockOfTryCatch);
return GetIsCompletedProperty(awaiterType, node, expression.Type!, diagnostics, out isCompleted)
&& AwaiterImplementsINotifyCompletion(awaiterType, node, diagnostics)
&& GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall);
&& (useGetResultException && GetGetResultExceptionMethod(getAwaiter, node, expression.Type!, out getResult, out getAwaiterGetResultCall)
|| GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall));
}

/// <summary>
Expand Down Expand Up @@ -433,43 +453,120 @@ private bool AwaiterImplementsINotifyCompletion(TypeSymbol awaiterType, SyntaxNo
/// </remarks>
private bool GetGetResultMethod(BoundExpression awaiterExpression, SyntaxNode node, TypeSymbol awaitedExpressionType, BindingDiagnosticBag diagnostics, out MethodSymbol? getResultMethod, [NotNullWhen(true)] out BoundExpression? getAwaiterGetResultCall)
{
var awaiterType = awaiterExpression.Type;
getAwaiterGetResultCall = MakeInvocationExpression(node, awaiterExpression, WellKnownMemberNames.GetResult, ImmutableArray<BoundExpression>.Empty, diagnostics);
if (getAwaiterGetResultCall.HasAnyErrors)
if (!ValidateGetResultMethod(awaiterExpression, node, awaitedExpressionType, getAwaiterGetResultCall, allowExtensions: false, diagnostics, getResultMethod: out getResultMethod))
{
getResultMethod = null;
getAwaiterGetResultCall = null;
return false;
}

return true;
}

private static bool ValidateGetResultMethod(BoundExpression awaiterExpression, SyntaxNode node, TypeSymbol awaitedExpressionType,
BoundExpression getAwaiterGetResultCall, bool allowExtensions, BindingDiagnosticBag diagnostics, [NotNullWhen(true)] out MethodSymbol? getResultMethod)
{
var awaiterType = awaiterExpression.Type;
getResultMethod = null;
if (getAwaiterGetResultCall.HasAnyErrors)
{
return false;
}

RoslynDebug.Assert(awaiterType is object);
if (getAwaiterGetResultCall.Kind != BoundKind.Call)
{
Error(diagnostics, ErrorCode.ERR_NoSuchMember, node, awaiterType, WellKnownMemberNames.GetResult);
getResultMethod = null;
getAwaiterGetResultCall = null;
return false;
}

getResultMethod = ((BoundCall)getAwaiterGetResultCall).Method;
if (getResultMethod.IsExtensionMethod)
if (!allowExtensions && getResultMethod.IsExtensionMethod)
{
Error(diagnostics, ErrorCode.ERR_NoSuchMember, node, awaiterType, WellKnownMemberNames.GetResult);
getResultMethod = null;
getAwaiterGetResultCall = null;
return false;
}

if (HasOptionalOrVariableParameters(getResultMethod) || getResultMethod.IsConditional)
{
Error(diagnostics, ErrorCode.ERR_BadAwaiterPattern, node, awaiterType, awaitedExpressionType);
return false;
}

return true;
}

/// <summary>
/// Finds the GetResult(out Exception) method of an Awaiter type.
/// </summary>
private bool GetGetResultExceptionMethod(BoundExpression awaiterExpression, SyntaxNode node, TypeSymbol awaitedExpressionType,
out MethodSymbol? getResultMethod, [NotNullWhen(true)] out BoundExpression? getAwaiterGetResultCall)
{
// PROTOTYPE should this binding logic be conditional on LangVer?
Debug.Assert(awaiterExpression.Type is not null);

var discarded = BindingDiagnosticBag.Discarded;
getAwaiterGetResultCall = tryMakeGetResultExceptionInvocation(node, awaiterExpression, discarded);

if (getAwaiterGetResultCall is null ||
!ValidateGetResultMethod(awaiterExpression, node, awaitedExpressionType, getAwaiterGetResultCall, allowExtensions: true, discarded, getResultMethod: out getResultMethod))
{
getResultMethod = null;
getAwaiterGetResultCall = null;
return false;
}

// The lack of a GetResult method will be reported by ValidateGetResult().
return true;

// Find the GetResult(out Exception) method.
BoundExpression? tryMakeGetResultExceptionInvocation(SyntaxNode node, BoundExpression awaiterExpression, BindingDiagnosticBag diagnostics)
{
Debug.Assert(awaiterExpression.Type is not null);
Debug.Assert(!awaiterExpression.Type.IsDynamic());

const string methodName = WellKnownMemberNames.GetResult;
var memberAccess = BindInstanceMemberAccess(
node, node, awaiterExpression, methodName, rightArity: 0,
typeArgumentsSyntax: default, typeArgumentsWithAnnotations: default,
invoked: true, indexed: false, diagnostics: diagnostics);

// PROTOTYPE we should not drop these diagnostics
//memberAccess = CheckValue(memberAccess, BindValueKind.RValueOrMethodGroup, diagnostics);
memberAccess.WasCompilerGenerated = true;

if (memberAccess.Kind != BoundKind.MethodGroup)
{
return null;
}

var analyzedArguments = AnalyzedArguments.GetInstance();

try
{
// PROTOTYPE handle missing System.Exception type
var outPlaceholder = new BoundGetResultOutExceptionPlaceholder(node, GetWellKnownType(WellKnownType.System_Exception, diagnostics, node));
analyzedArguments.Arguments.Add(outPlaceholder);
analyzedArguments.RefKinds.Add(RefKind.Out);

BoundExpression getAwaiterGetResultCall = BindMethodGroupInvocation(
node, node, methodName, (BoundMethodGroup)memberAccess, analyzedArguments, diagnostics, queryClause: null,
allowUnexpandedForm: true, anyApplicableCandidates: out bool anyApplicableCandidates);

getAwaiterGetResultCall.WasCompilerGenerated = true;

if (!anyApplicableCandidates) // PROTOTYPE is this necessary/useful?
{
return null;
}

return getAwaiterGetResultCall;
}
finally
{
analyzedArguments.Free();
}
}
}

private static bool HasOptionalOrVariableParameters(MethodSymbol method)
Expand Down
5 changes: 5 additions & 0 deletions src/Compilers/CSharp/Portable/BoundTree/BoundExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ internal partial class BoundAwaitableValuePlaceholder
public sealed override bool IsEquivalentToThisReference => false; // Preserving old behavior
}

internal partial class BoundGetResultOutExceptionPlaceholder
{
public sealed override bool IsEquivalentToThisReference => throw ExceptionUtilities.Unreachable();
}

internal partial class BoundDisposableValuePlaceholder
{
public sealed override bool IsEquivalentToThisReference => false;
Expand Down
9 changes: 9 additions & 0 deletions src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@
<Field Name="IsDiscardExpression" Type="bool" Null="NotApplicable"/>
</Node>

<!--
This node is used to represent the out argument to the `GetResult(out System.Exception)`
invocation in an async method.
It does not survive the initial binding.
-->
<Node Name="BoundGetResultOutExceptionPlaceholder" Base="BoundValuePlaceholderBase">
<Field Name="Type" Type="TypeSymbol" Override="true" Null="disallow"/>
</Node>

<!--
In a tuple binary operator, this node is used to represent tuple elements in a tuple binary
operator, and to represent an element-wise comparison result to convert back to bool.
Expand Down
26 changes: 17 additions & 9 deletions src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1570,9 +1570,10 @@ private static BoundNode Rewrite(ImmutableDictionary<BoundExpression, (Nullabili
remappedSymbolsBuilder.AddRange(remappedSymbols);
}
var rewriter = new NullabilityRewriter(updatedNullabilities, snapshotManager, remappedSymbolsBuilder);
var rewrittenNode = rewriter.Visit(node);
//var rewrittenNode = rewriter.Visit(node);
remappedSymbols = remappedSymbolsBuilder.ToImmutable();
return rewrittenNode;
//return rewrittenNode;
return node;
}

private static bool HasRequiredLanguageVersion(CSharpCompilation compilation)
Expand Down Expand Up @@ -1739,10 +1740,10 @@ private static void Analyze(
Debug.Assert(walker._variables.Id == initialState.Value.Id);
}
#endif
bool badRegion = false;
ImmutableArray<PendingBranch> returns = walker.Analyze(ref badRegion, initialState);
diagnostics?.AddRange(walker.Diagnostics);
Debug.Assert(!badRegion);
//bool badRegion = false;
//ImmutableArray<PendingBranch> returns = walker.Analyze(ref badRegion, initialState);
//diagnostics?.AddRange(walker.Diagnostics);
//Debug.Assert(!badRegion);
}
catch (CancelledByStackGuardException ex) when (diagnostics != null)
{
Expand Down Expand Up @@ -10811,10 +10812,17 @@ private TypeWithState InferResultNullabilityOfBinaryLogicalOperator(BoundExpress
// Proper handling of this is additional work which only benefits a very uncommon scenario,
// so we will just use the originally bound GetResult method in this case.
var getResult = awaitableInfo.GetResult;
var reinferredGetResult = _visitResult.RValueType.Type is NamedTypeSymbol taskAwaiterType
? getResult.OriginalDefinition.AsMember(taskAwaiterType)
: getResult;

MethodSymbol? reinferredGetResult;
if (_visitResult.RValueType.Type is NamedTypeSymbol taskAwaiterType)
{
// TODO2 crash here
reinferredGetResult = getResult.OriginalDefinition.AsMember(taskAwaiterType);
}
else
{
reinferredGetResult = getResult;
}
SetResultType(node, reinferredGetResult.ReturnTypeWithAnnotations.ToTypeWithState());
}

Expand Down
Loading