Skip to content

Commit

Permalink
Merge pull request #227 from Microsoft/dev/andarno/fix226
Browse files Browse the repository at this point in the history
Fix VSTHRD010 transitivity to not mis-fire when invoking async methods
  • Loading branch information
AArnott authored Mar 31, 2018
2 parents e221f73 + b449cf7 commit 12b776d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,32 @@ void VerifyOnUIThread() {
this.VerifyCSharpDiagnostic(test, this.expect);
}

[Fact]
public void RequiresUIThread_NotTransitiveThroughAsyncCalls()
{
var test = @"
using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.Threading;
class Test {
private JoinableTaskFactory jtf;
private void ShowToolWindow(object sender, EventArgs e) {
jtf.RunAsync(async delegate {
await FooAsync(); // this line is what adds the VSTHRD010 diagnostic
});
}
private async Task FooAsync() {
await jtf.SwitchToMainThreadAsync();
}
}
";

this.VerifyCSharpDiagnostic(test);
}

[Fact]
public void InvokeVsSolutionAfterSwitchedToMainThreadAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ public override void Initialize(AnalysisContext context)
var typesRequiringMainThread = CommonInterest.ReadTypes(compilationStartContext, CommonInterest.FileNamePatternForTypesRequiringMainThread).ToImmutableArray();

var methodsDeclaringUIThreadRequirement = new HashSet<IMethodSymbol>();
var methodsAssertingUIThreadRequirement = new HashSet<IMethodSymbol>();
var callerToCalleeMap = new Dictionary<IMethodSymbol, HashSet<IMethodSymbol>>();

compilationStartContext.RegisterCodeBlockStartAction<SyntaxKind>(codeBlockStartContext =>
Expand All @@ -125,6 +126,7 @@ public override void Initialize(AnalysisContext context)
MainThreadSwitchingMethods = mainThreadSwitchingMethods,
TypesRequiringMainThread = typesRequiringMainThread,
MethodsDeclaringUIThreadRequirement = methodsDeclaringUIThreadRequirement,
MethodsAssertingUIThreadRequirement = methodsAssertingUIThreadRequirement,
};
codeBlockStartContext.RegisterSyntaxNodeAction(Utils.DebuggableWrapper(methodAnalyzer.AnalyzeInvocation), SyntaxKind.InvocationExpression);
codeBlockStartContext.RegisterSyntaxNodeAction(Utils.DebuggableWrapper(methodAnalyzer.AnalyzeMemberAccess), SyntaxKind.SimpleMemberAccessExpression);
Expand All @@ -140,7 +142,7 @@ public override void Initialize(AnalysisContext context)
compilationStartContext.RegisterCompilationEndAction(compilationEndContext =>
{
var calleeToCallerMap = CreateCalleeToCallerMap(callerToCalleeMap);
var transitiveClosureOfMainThreadRequiringMethods = GetTransitiveClosureOfMainThreadRequiringMethods(methodsDeclaringUIThreadRequirement, calleeToCallerMap);
var transitiveClosureOfMainThreadRequiringMethods = GetTransitiveClosureOfMainThreadRequiringMethods(methodsAssertingUIThreadRequirement, calleeToCallerMap);
foreach (var implicitUserMethod in transitiveClosureOfMainThreadRequiringMethods.Except(methodsDeclaringUIThreadRequirement))
{
var declarationSyntax = implicitUserMethod.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(compilationEndContext.CancellationToken);
Expand Down Expand Up @@ -268,6 +270,8 @@ private class MethodAnalyzer

internal HashSet<IMethodSymbol> MethodsDeclaringUIThreadRequirement { get; set; }

internal HashSet<IMethodSymbol> MethodsAssertingUIThreadRequirement { get; set; }

internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context)
{
var invocationSyntax = (InvocationExpressionSyntax)context.Node;
Expand All @@ -277,14 +281,24 @@ internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context)
var methodDeclaration = context.Node.FirstAncestorOrSelf<SyntaxNode>(n => CommonInterest.MethodSyntaxKinds.Contains(n.Kind()));
if (methodDeclaration != null)
{
if (this.MainThreadAssertingMethods.Contains(invokedMethod) || this.MainThreadSwitchingMethods.Contains(invokedMethod))
bool assertsMainThread = this.MainThreadAssertingMethods.Contains(invokedMethod);
bool switchesToMainThread = this.MainThreadSwitchingMethods.Contains(invokedMethod);
if (assertsMainThread || switchesToMainThread)
{
if (context.ContainingSymbol is IMethodSymbol callingMethod)
{
lock (this.MethodsDeclaringUIThreadRequirement)
{
this.MethodsDeclaringUIThreadRequirement.Add(callingMethod);
}

if (assertsMainThread)
{
lock (this.MethodsAssertingUIThreadRequirement)
{
this.MethodsAssertingUIThreadRequirement.Add(callingMethod);
}
}
}

this.methodDeclarationNodes = this.methodDeclarationNodes.SetItem(methodDeclaration, ThreadingContext.MainThread);
Expand Down

0 comments on commit 12b776d

Please sign in to comment.