diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs b/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs index 8a35b72734..3506ac7e0f 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs @@ -89,7 +89,7 @@ public static IApplicationBuilder UseMvc( var routes = new RouteBuilder(app) { - DefaultHandler = new MvcRouteHandler(), + DefaultHandler = app.ApplicationServices.GetRequiredService(), }; configureRoutes(routes); diff --git a/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs index 162b986b59..078efd9b94 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs @@ -207,6 +207,11 @@ internal static void AddMvcCoreServices(IServiceCollection services) services.TryAddSingleton(ArrayPool.Shared); services.TryAddSingleton(ArrayPool.Shared); services.TryAddSingleton(); + + // + // Setup default handler + // + services.TryAddSingleton(); } private static void ConfigureDefaultServices(IServiceCollection services) diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcRouteHandler.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcRouteHandler.cs index a8b372f860..efcf9e539a 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcRouteHandler.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcRouteHandler.cs @@ -10,21 +10,44 @@ using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing.Tree; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Mvc.Internal { public class MvcRouteHandler : IRouter { - private bool _servicesRetrieved; - private IActionContextAccessor _actionContextAccessor; private IActionInvokerFactory _actionInvokerFactory; private IActionSelector _actionSelector; private ILogger _logger; private DiagnosticSource _diagnosticSource; + public MvcRouteHandler( + IActionInvokerFactory actionInvokerFactory, + IActionSelector actionSelector, + DiagnosticSource diagnosticSource, + ILoggerFactory loggerFactory) + : this(actionInvokerFactory, actionSelector, diagnosticSource, loggerFactory, actionContextAccessor: null) + { + } + + public MvcRouteHandler( + IActionInvokerFactory actionInvokerFactory, + IActionSelector actionSelector, + DiagnosticSource diagnosticSource, + ILoggerFactory loggerFactory, + IActionContextAccessor actionContextAccessor) + { + // The IActionContextAccessor is optional. We want to avoid the overhead of using CallContext + // if possible. + _actionContextAccessor = actionContextAccessor; + + _actionInvokerFactory = actionInvokerFactory; + _actionSelector = actionSelector; + _diagnosticSource = diagnosticSource; + _logger = loggerFactory.CreateLogger(); + } + public VirtualPathData GetVirtualPath(VirtualPathContext context) { if (context == null) @@ -43,8 +66,6 @@ public Task RouteAsync(RouteContext context) throw new ArgumentNullException(nameof(context)); } - EnsureServices(context.HttpContext); - var actionDescriptor = _actionSelector.Select(context); if (actionDescriptor == null) { @@ -107,28 +128,5 @@ private async Task InvokeActionAsync(HttpContext httpContext, ActionDescriptor a _diagnosticSource.AfterAction(actionDescriptor, httpContext, routeData); } } - - private void EnsureServices(HttpContext context) - { - if (_servicesRetrieved) - { - return; - } - - var services = context.RequestServices; - - // The IActionContextAccessor is optional. We want to avoid the overhead of using CallContext - // if possible. - _actionContextAccessor = services.GetService(); - - _actionInvokerFactory = services.GetRequiredService(); - _actionSelector = services.GetRequiredService(); - _diagnosticSource = services.GetRequiredService(); - - var factory = services.GetRequiredService(); - _logger = factory.CreateLogger(); - - _servicesRetrieved = true; - } } } diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Infrastructure/MvcRouteHandlerTests.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Infrastructure/MvcRouteHandlerTests.cs index e42b9ff808..406c19feb3 100644 --- a/test/Microsoft.AspNetCore.Mvc.Core.Test/Infrastructure/MvcRouteHandlerTests.cs +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Infrastructure/MvcRouteHandlerTests.cs @@ -32,9 +32,9 @@ public async Task RouteHandler_Success_LogsCorrectValues() .SetupGet(ad => ad.DisplayName) .Returns(displayName); - var context = CreateRouteContext(actionDescriptor: actionDescriptor.Object, loggerFactory: loggerFactory); + var context = CreateRouteContext(); - var handler = new MvcRouteHandler(); + var handler = CreateMvcRouteHandler(actionDescriptor: actionDescriptor.Object, loggerFactory: loggerFactory); await handler.RouteAsync(context); // Act @@ -62,11 +62,12 @@ public async Task RouteAsync_FailOnNoAction_LogsCorrectValues() .Setup(a => a.Select(It.IsAny())) .Returns(null); - var context = CreateRouteContext( + var context = CreateRouteContext(); + + var handler = CreateMvcRouteHandler( actionSelector: mockActionSelector.Object, loggerFactory: loggerFactory); - var handler = new MvcRouteHandler(); var expectedMessage = "No actions matched the current request"; // Act @@ -95,8 +96,8 @@ public async Task RouteHandler_RemovesRouteGroupFromRouteValues() return invoker.Object; }); - var context = CreateRouteContext(invokerFactory: invokerFactory.Object); - var handler = new MvcRouteHandler(); + var context = CreateRouteContext(); + var handler = CreateMvcRouteHandler(invokerFactory: invokerFactory.Object); var originalRouteData = context.RouteData; originalRouteData.Values.Add(TreeRouter.RouteGroupKey, "/Home/Test"); @@ -117,10 +118,10 @@ public async Task RouteHandler_WritesDiagnostic_ActionSelected() // Arrange var listener = new TestDiagnosticListener(); - var context = CreateRouteContext(diagnosticListener: listener); + var context = CreateRouteContext(); context.RouteData.Values.Add("tag", "value"); - var handler = new MvcRouteHandler(); + var handler = CreateMvcRouteHandler(diagnosticListener: listener); await handler.RouteAsync(context); // Act @@ -143,9 +144,9 @@ public async Task RouteHandler_WritesDiagnostic_ActionInvoked() // Arrange var listener = new TestDiagnosticListener(); - var context = CreateRouteContext(diagnosticListener: listener); + var context = CreateRouteContext(); - var handler = new MvcRouteHandler(); + var handler = CreateMvcRouteHandler(diagnosticListener: listener); await handler.RouteAsync(context); // Act @@ -156,14 +157,15 @@ public async Task RouteHandler_WritesDiagnostic_ActionInvoked() Assert.NotNull(listener.AfterAction?.HttpContext); } - private RouteContext CreateRouteContext( + private MvcRouteHandler CreateMvcRouteHandler( ActionDescriptor actionDescriptor = null, IActionSelector actionSelector = null, IActionInvokerFactory invokerFactory = null, ILoggerFactory loggerFactory = null, - IOptions optionsAccessor = null, object diagnosticListener = null) { + var actionContextAccessor = new ActionContextAccessor(); + if (actionDescriptor == null) { var mockAction = new Mock(); @@ -175,10 +177,20 @@ private RouteContext CreateRouteContext( var mockActionSelector = new Mock(); mockActionSelector.Setup(a => a.Select(It.IsAny())) .Returns(actionDescriptor); - actionSelector = mockActionSelector.Object; } + if (loggerFactory == null) + { + loggerFactory = NullLoggerFactory.Instance; + } + + var diagnosticSource = new DiagnosticListener("Microsoft.AspNetCore"); + if (diagnosticListener != null) + { + diagnosticSource.SubscribeWithAdapter(diagnosticListener); + } + if (invokerFactory == null) { var mockInvoker = new Mock(); @@ -192,46 +204,19 @@ private RouteContext CreateRouteContext( invokerFactory = mockInvokerFactory.Object; } - if (loggerFactory == null) - { - loggerFactory = NullLoggerFactory.Instance; - } - - if (optionsAccessor == null) - { - optionsAccessor = new TestOptionsManager(); - } - - var diagnosticSource = new DiagnosticListener("Microsoft.AspNetCore"); - if (diagnosticListener != null) - { - diagnosticSource.SubscribeWithAdapter(diagnosticListener); - } + return new MvcRouteHandler( + invokerFactory, + actionSelector, + diagnosticSource, + loggerFactory, + actionContextAccessor); + } + private RouteContext CreateRouteContext() + { var routingFeature = new RoutingFeature(); var httpContext = new Mock(); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(IActionContextAccessor))) - .Returns(new ActionContextAccessor()); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(IActionSelector))) - .Returns(actionSelector); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(IActionInvokerFactory))) - .Returns(invokerFactory); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(ILoggerFactory))) - .Returns(loggerFactory); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(MvcMarkerService))) - .Returns(new MvcMarkerService()); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(IOptions))) - .Returns(optionsAccessor); - httpContext - .Setup(h => h.RequestServices.GetService(typeof(DiagnosticSource))) - .Returns(diagnosticSource); httpContext .Setup(h => h.Features[typeof(IRoutingFeature)]) .Returns(routingFeature);