From f1755de81b0de38c42722fcd16df94e0ea3b9988 Mon Sep 17 00:00:00 2001 From: Artur Signell Date: Fri, 3 Feb 2023 15:08:30 +0200 Subject: [PATCH] fix: Dispatch to the given servlet when using getNamedDispatcher --- .../servlet/AwsProxyRequestDispatcher.java | 7 ++- .../internal/servlet/AwsServletContext.java | 2 +- .../internal/servlet/FilterChainManager.java | 26 +++++++-- .../proxy/internal/testutils/MockServlet.java | 23 ++++++++ .../servlet/AwsFilterChainManagerTest.java | 53 ++++++++++++++++++- .../servlet/AwsServletContextTest.java | 18 +++---- 6 files changed, 111 insertions(+), 18 deletions(-) create mode 100644 aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/MockServlet.java diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyRequestDispatcher.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyRequestDispatcher.java index 27a352ebf..fc52d1dba 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyRequestDispatcher.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyRequestDispatcher.java @@ -80,7 +80,7 @@ public void forward(ServletRequest servletRequest, ServletResponse servletRespon } if (isNamedDispatcher) { - lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet((HttpServletRequest)servletRequest)); + lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet(dispatchTo)); return; } @@ -148,4 +148,9 @@ void setRequestPath(ServletRequest req, final String destinationPath) { private Servlet getServlet(HttpServletRequest req) { return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServletForPath(req.getPathInfo()); } + + private Servlet getServlet(String servletName) throws ServletException { + return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServlet(servletName); + } + } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java index 864a9bf10..44ec701ee 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContext.java @@ -171,7 +171,7 @@ public RequestDispatcher getRequestDispatcher(String s) { @Override public RequestDispatcher getNamedDispatcher(String s) { - throw new UnsupportedOperationException(); + return new AwsProxyRequestDispatcher(s, true, containerHandler); } diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java index 92ff55daf..1bf6f4b2a 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/FilterChainManager.java @@ -132,7 +132,7 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servletRegistration), servletContext)); } - putFilterChainCache(type, targetPath, chainHolder); + putFilterChainCache(type, targetPath, servlet, chainHolder); // update total filter size if (filtersSize != registrations.size()) { filtersSize = registrations.size(); @@ -151,13 +151,16 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl * initialized with the cached list of {@link FilterHolder} objects * @param type The dispatcher type for the incoming request * @param targetPath The request path - this is extracted with the getPath method of the request object - * @param servlet Servlet to put at the end of the chain (optional). + * @param servlet The final servlet in the filter chain (if any) * @return A populated FilterChainHolder */ private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet) { TargetCacheKey key = new TargetCacheKey(); key.setDispatcherType(type); key.setTargetPath(targetPath); + if (servlet != null) { + key.setServletName(servlet.getServletConfig().getServletName()); + } if (!filterCache.containsKey(key)) { return null; @@ -174,12 +177,16 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S * method to retry this. * @param type DispatcherType from the incoming request * @param targetPath The target path in the API - * @param holder The FilterChainHolder object to save in the cache + * @param servlet The final servlet in the filter chain (if any) + * @param holder The FilterChainHolder object to save in the cache */ - private void putFilterChainCache(final DispatcherType type, final String targetPath, final FilterChainHolder holder) { + private void putFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet, final FilterChainHolder holder) { TargetCacheKey key = new TargetCacheKey(); key.setDispatcherType(type); key.setTargetPath(targetPath); + if (servlet != null) { + key.setServletName(servlet.getServletConfig().getServletName()); + } // we couldn't compute the hash code because either the target path or dispatcher type were null if (key.hashCode() == -1) { @@ -256,6 +263,7 @@ protected static class TargetCacheKey { private String targetPath; private DispatcherType dispatcherType; + private String servletName; //------------------------------------------------------------- @@ -295,10 +303,15 @@ public int hashCode() { } hashString += ":" + hashDispatcher; + if (servletName != null) { + hashString += ":" + servletName; + } + return hashString.hashCode(); } + @Override public boolean equals(Object key) { if (key == null) { @@ -324,6 +337,11 @@ void setTargetPath(String targetPath) { void setDispatcherType(DispatcherType dispatcherType) { this.dispatcherType = dispatcherType; } + public void setServletName(String servletName) { + this.servletName = servletName; + } + + } @SuppressFBWarnings("URF_UNREAD_FIELD") diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/MockServlet.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/MockServlet.java new file mode 100644 index 000000000..b93625bfb --- /dev/null +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/testutils/MockServlet.java @@ -0,0 +1,23 @@ +package com.amazonaws.serverless.proxy.internal.testutils; + +import java.io.IOException; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +public class MockServlet extends HttpServlet { + + private int serviceCalls = 0; + + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + super.service(req, resp); + serviceCalls++; + } + + public int getServiceCalls() { + return serviceCalls; + } +} diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java index 882aa085c..4eedc0f93 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsFilterChainManagerTest.java @@ -2,6 +2,7 @@ import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder; import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext; +import com.amazonaws.serverless.proxy.internal.testutils.MockServlet; import com.amazonaws.services.lambda.runtime.Context; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -17,6 +18,8 @@ import static org.junit.jupiter.api.Assertions.*; public class AwsFilterChainManagerTest { + private static final String SERVLET1_NAME = "Servlet 1"; + private static final String SERVLET2_NAME = "Servlet 2"; private static final String REQUEST_CUSTOM_ATTRIBUTE_NAME = "X-Custom-Attribute"; private static final String REQUEST_CUSTOM_ATTRIBUTE_VALUE = "CustomAttrValue"; @@ -36,6 +39,10 @@ public static void setUp() { reg2.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/second/*"); FilterRegistration.Dynamic reg3 = servletContext.addFilter("Filter3", new MockFilter()); reg3.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/third/fourth/*"); + ServletRegistration.Dynamic firstServlet = servletContext.addServlet(SERVLET1_NAME, new MockServlet()); + firstServlet.addMapping("/first/*"); + ServletRegistration.Dynamic secondServlet = servletContext.addServlet(SERVLET2_NAME, new MockServlet()); + secondServlet.addMapping("/second/*"); chainManager = new AwsFilterChainManager((AwsServletContext) servletContext); } @@ -88,6 +95,22 @@ void cacheKey_compare_differentDispatcher() { assertNotEquals(cacheKey, secondCacheKey); } + @Test + void cacheKey_compare_differentServlet() { + FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey(); + cacheKey.setDispatcherType(DispatcherType.REQUEST); + cacheKey.setTargetPath("/first/path"); + cacheKey.setServletName("Dispatcher servlet"); + + FilterChainManager.TargetCacheKey secondCacheKey = new FilterChainManager.TargetCacheKey(); + secondCacheKey.setDispatcherType(DispatcherType.REQUEST); + secondCacheKey.setTargetPath("/first/path"); + cacheKey.setServletName("Real servlet"); + + assertNotEquals(cacheKey.hashCode(), secondCacheKey.hashCode()); + assertNotEquals(cacheKey, secondCacheKey); + } + @Test void cacheKey_compare_additionalChars() { FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey(); @@ -154,7 +177,7 @@ void filterChain_matchMultipleTimes_expectSameMatch() { } @Test - void filerChain_executeMultipleFilters_expectRunEachTime() { + void filterChain_executeMultipleFilters_expectRunEachTime() { AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest( new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null ); @@ -204,6 +227,34 @@ void filerChain_executeMultipleFilters_expectRunEachTime() { assertEquals(REQUEST_CUSTOM_ATTRIBUTE_VALUE, req2.getAttribute(REQUEST_CUSTOM_ATTRIBUTE_NAME)); } + @Test + void filterChain_multipleServlets_callsCorrectServlet() throws IOException, ServletException { + MockServlet servlet1 = (MockServlet) servletContext.getServlet(SERVLET1_NAME); + ServletConfig servlet1Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET1_NAME)).getServletConfig(); + servlet1.init(servlet1Config); + + MockServlet servlet2 = (MockServlet) servletContext.getServlet(SERVLET2_NAME); + ServletConfig servlet2Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET2_NAME)).getServletConfig(); + servlet2.init(servlet2Config); + + AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest( + new AwsProxyRequestBuilder("/", "GET").build(), lambdaContext, null + ); + AwsHttpServletResponse resp = new AwsHttpServletResponse(req, new CountDownLatch(1)); + + FilterChainHolder servlet1filterChain = chainManager.getFilterChain(req, servlet1); + servlet1filterChain.doFilter(req, resp); + + assertEquals(1, servlet1.getServiceCalls()); + assertEquals(0, servlet2.getServiceCalls()); + + FilterChainHolder servlet2filterChain = chainManager.getFilterChain(req, servlet2); + servlet2filterChain.doFilter(req, resp); + + assertEquals(1, servlet1.getServiceCalls()); + assertEquals(1, servlet2.getServiceCalls()); + } + @Test void filterChain_getFilterChain_multipleFilters() { AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest( diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContextTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContextTest.java index 5c84bca30..35f6723b6 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContextTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsServletContextTest.java @@ -13,12 +13,7 @@ import javax.servlet.http.HttpServletResponse; import java.io.File; -import java.io.FileNotFoundException; import java.io.IOException; -import java.io.PrintWriter; -import java.io.UnsupportedEncodingException; -import java.nio.file.Files; -import java.nio.file.Paths; import java.util.concurrent.CountDownLatch; import static org.junit.jupiter.api.Assertions.*; @@ -190,12 +185,7 @@ void unsupportedOperations_expectExceptions() { } catch (UnsupportedOperationException e) { exCount++; } - try { - STATIC_CTX.getNamedDispatcher("1"); - } catch (UnsupportedOperationException e) { - exCount++; - } - assertEquals(2, exCount); + assertEquals(1, exCount); assertNull(STATIC_CTX.getServletRegistration("1")); } @@ -232,6 +222,12 @@ void addServlet_callsDefaultConstructor() throws ServletException { assertEquals("", ((TestServlet)ctx.getServlet("srv1")).getId()); } + @Test + void getNamedDispatcher_returnsDispatcher() { + AwsServletContext ctx = new AwsServletContext(null); + assertNotNull(ctx.getNamedDispatcher("/hello")); + } + public static class TestServlet implements Servlet { private String id;