diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs index 895eb9201..3075d8fdd 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs @@ -382,6 +382,32 @@ void VerifyOnUIThread() { this.VerifyCSharpDiagnostic(test, this.expect); } + [Fact] + public void RequiresUIThread_NotTransitiveThroughAsyncCalls() + { + var test = @" +using System; +using System.Threading.Tasks; +using Microsoft.VisualStudio.Threading; + +class Test { + private JoinableTaskFactory jtf; + + private void ShowToolWindow(object sender, EventArgs e) { + jtf.RunAsync(async delegate { + await FooAsync(); // this line is what adds the VSTHRD010 diagnostic + }); + } + + private async Task FooAsync() { + await jtf.SwitchToMainThreadAsync(); + } +} +"; + + this.VerifyCSharpDiagnostic(test); + } + [Fact] public void InvokeVsSolutionAfterSwitchedToMainThreadAsync() { diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs index d9640513f..fdd80232f 100644 --- a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs +++ b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs @@ -115,6 +115,7 @@ public override void Initialize(AnalysisContext context) var typesRequiringMainThread = CommonInterest.ReadTypes(compilationStartContext, CommonInterest.FileNamePatternForTypesRequiringMainThread).ToImmutableArray(); var methodsDeclaringUIThreadRequirement = new HashSet<IMethodSymbol>(); + var methodsAssertingUIThreadRequirement = new HashSet<IMethodSymbol>(); var callerToCalleeMap = new Dictionary<IMethodSymbol, HashSet<IMethodSymbol>>(); compilationStartContext.RegisterCodeBlockStartAction<SyntaxKind>(codeBlockStartContext => @@ -125,6 +126,7 @@ public override void Initialize(AnalysisContext context) MainThreadSwitchingMethods = mainThreadSwitchingMethods, TypesRequiringMainThread = typesRequiringMainThread, MethodsDeclaringUIThreadRequirement = methodsDeclaringUIThreadRequirement, + MethodsAssertingUIThreadRequirement = methodsAssertingUIThreadRequirement, }; codeBlockStartContext.RegisterSyntaxNodeAction(Utils.DebuggableWrapper(methodAnalyzer.AnalyzeInvocation), SyntaxKind.InvocationExpression); codeBlockStartContext.RegisterSyntaxNodeAction(Utils.DebuggableWrapper(methodAnalyzer.AnalyzeMemberAccess), SyntaxKind.SimpleMemberAccessExpression); @@ -140,7 +142,7 @@ public override void Initialize(AnalysisContext context) compilationStartContext.RegisterCompilationEndAction(compilationEndContext => { var calleeToCallerMap = CreateCalleeToCallerMap(callerToCalleeMap); - var transitiveClosureOfMainThreadRequiringMethods = GetTransitiveClosureOfMainThreadRequiringMethods(methodsDeclaringUIThreadRequirement, calleeToCallerMap); + var transitiveClosureOfMainThreadRequiringMethods = GetTransitiveClosureOfMainThreadRequiringMethods(methodsAssertingUIThreadRequirement, calleeToCallerMap); foreach (var implicitUserMethod in transitiveClosureOfMainThreadRequiringMethods.Except(methodsDeclaringUIThreadRequirement)) { var declarationSyntax = implicitUserMethod.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(compilationEndContext.CancellationToken); @@ -268,6 +270,8 @@ private class MethodAnalyzer internal HashSet<IMethodSymbol> MethodsDeclaringUIThreadRequirement { get; set; } + internal HashSet<IMethodSymbol> MethodsAssertingUIThreadRequirement { get; set; } + internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context) { var invocationSyntax = (InvocationExpressionSyntax)context.Node; @@ -277,7 +281,9 @@ internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context) var methodDeclaration = context.Node.FirstAncestorOrSelf<SyntaxNode>(n => CommonInterest.MethodSyntaxKinds.Contains(n.Kind())); if (methodDeclaration != null) { - if (this.MainThreadAssertingMethods.Contains(invokedMethod) || this.MainThreadSwitchingMethods.Contains(invokedMethod)) + bool assertsMainThread = this.MainThreadAssertingMethods.Contains(invokedMethod); + bool switchesToMainThread = this.MainThreadSwitchingMethods.Contains(invokedMethod); + if (assertsMainThread || switchesToMainThread) { if (context.ContainingSymbol is IMethodSymbol callingMethod) { @@ -285,6 +291,14 @@ internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context) { this.MethodsDeclaringUIThreadRequirement.Add(callingMethod); } + + if (assertsMainThread) + { + lock (this.MethodsAssertingUIThreadRequirement) + { + this.MethodsAssertingUIThreadRequirement.Add(callingMethod); + } + } } this.methodDeclarationNodes = this.methodDeclarationNodes.SetItem(methodDeclaration, ThreadingContext.MainThread);