diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs
index 895eb9201..3075d8fdd 100644
--- a/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs
+++ b/src/Microsoft.VisualStudio.Threading.Analyzers.Tests/VSTHRD010MainThreadUsageAnalyzerTests.cs
@@ -382,6 +382,32 @@ void VerifyOnUIThread() {
             this.VerifyCSharpDiagnostic(test, this.expect);
         }
 
+        [Fact]
+        public void RequiresUIThread_NotTransitiveThroughAsyncCalls()
+        {
+            var test = @"
+using System;
+using System.Threading.Tasks;
+using Microsoft.VisualStudio.Threading;
+
+class Test {
+    private JoinableTaskFactory jtf;
+
+    private void ShowToolWindow(object sender, EventArgs e) {
+        jtf.RunAsync(async delegate {
+            await FooAsync(); // this line is what adds the VSTHRD010 diagnostic
+        });
+    }
+
+    private async Task FooAsync() {
+        await jtf.SwitchToMainThreadAsync();
+    }
+}
+";
+
+            this.VerifyCSharpDiagnostic(test);
+        }
+
         [Fact]
         public void InvokeVsSolutionAfterSwitchedToMainThreadAsync()
         {
diff --git a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs
index d9640513f..fdd80232f 100644
--- a/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs
+++ b/src/Microsoft.VisualStudio.Threading.Analyzers/VSTHRD010MainThreadUsageAnalyzer.cs
@@ -115,6 +115,7 @@ public override void Initialize(AnalysisContext context)
                 var typesRequiringMainThread = CommonInterest.ReadTypes(compilationStartContext, CommonInterest.FileNamePatternForTypesRequiringMainThread).ToImmutableArray();
 
                 var methodsDeclaringUIThreadRequirement = new HashSet<IMethodSymbol>();
+                var methodsAssertingUIThreadRequirement = new HashSet<IMethodSymbol>();
                 var callerToCalleeMap = new Dictionary<IMethodSymbol, HashSet<IMethodSymbol>>();
 
                 compilationStartContext.RegisterCodeBlockStartAction<SyntaxKind>(codeBlockStartContext =>
@@ -125,6 +126,7 @@ public override void Initialize(AnalysisContext context)
                         MainThreadSwitchingMethods = mainThreadSwitchingMethods,
                         TypesRequiringMainThread = typesRequiringMainThread,
                         MethodsDeclaringUIThreadRequirement = methodsDeclaringUIThreadRequirement,
+                        MethodsAssertingUIThreadRequirement = methodsAssertingUIThreadRequirement,
                     };
                     codeBlockStartContext.RegisterSyntaxNodeAction(Utils.DebuggableWrapper(methodAnalyzer.AnalyzeInvocation), SyntaxKind.InvocationExpression);
                     codeBlockStartContext.RegisterSyntaxNodeAction(Utils.DebuggableWrapper(methodAnalyzer.AnalyzeMemberAccess), SyntaxKind.SimpleMemberAccessExpression);
@@ -140,7 +142,7 @@ public override void Initialize(AnalysisContext context)
                 compilationStartContext.RegisterCompilationEndAction(compilationEndContext =>
                 {
                     var calleeToCallerMap = CreateCalleeToCallerMap(callerToCalleeMap);
-                    var transitiveClosureOfMainThreadRequiringMethods = GetTransitiveClosureOfMainThreadRequiringMethods(methodsDeclaringUIThreadRequirement, calleeToCallerMap);
+                    var transitiveClosureOfMainThreadRequiringMethods = GetTransitiveClosureOfMainThreadRequiringMethods(methodsAssertingUIThreadRequirement, calleeToCallerMap);
                     foreach (var implicitUserMethod in transitiveClosureOfMainThreadRequiringMethods.Except(methodsDeclaringUIThreadRequirement))
                     {
                         var declarationSyntax = implicitUserMethod.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax(compilationEndContext.CancellationToken);
@@ -268,6 +270,8 @@ private class MethodAnalyzer
 
             internal HashSet<IMethodSymbol> MethodsDeclaringUIThreadRequirement { get; set; }
 
+            internal HashSet<IMethodSymbol> MethodsAssertingUIThreadRequirement { get; set; }
+
             internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context)
             {
                 var invocationSyntax = (InvocationExpressionSyntax)context.Node;
@@ -277,7 +281,9 @@ internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context)
                     var methodDeclaration = context.Node.FirstAncestorOrSelf<SyntaxNode>(n => CommonInterest.MethodSyntaxKinds.Contains(n.Kind()));
                     if (methodDeclaration != null)
                     {
-                        if (this.MainThreadAssertingMethods.Contains(invokedMethod) || this.MainThreadSwitchingMethods.Contains(invokedMethod))
+                        bool assertsMainThread = this.MainThreadAssertingMethods.Contains(invokedMethod);
+                        bool switchesToMainThread = this.MainThreadSwitchingMethods.Contains(invokedMethod);
+                        if (assertsMainThread || switchesToMainThread)
                         {
                             if (context.ContainingSymbol is IMethodSymbol callingMethod)
                             {
@@ -285,6 +291,14 @@ internal void AnalyzeInvocation(SyntaxNodeAnalysisContext context)
                                 {
                                     this.MethodsDeclaringUIThreadRequirement.Add(callingMethod);
                                 }
+
+                                if (assertsMainThread)
+                                {
+                                    lock (this.MethodsAssertingUIThreadRequirement)
+                                    {
+                                        this.MethodsAssertingUIThreadRequirement.Add(callingMethod);
+                                    }
+                                }
                             }
 
                             this.methodDeclarationNodes = this.methodDeclarationNodes.SetItem(methodDeclaration, ThreadingContext.MainThread);