Skip to content

Commit

Permalink
fix: Dispatch to the given servlet when using getNamedDispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur- committed Feb 3, 2023
1 parent be15cfd commit f1755de
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void forward(ServletRequest servletRequest, ServletResponse servletRespon
}

if (isNamedDispatcher) {
lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet((HttpServletRequest)servletRequest));
lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet(dispatchTo));
return;
}

Expand Down Expand Up @@ -148,4 +148,9 @@ void setRequestPath(ServletRequest req, final String destinationPath) {
private Servlet getServlet(HttpServletRequest req) {
return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServletForPath(req.getPathInfo());
}

private Servlet getServlet(String servletName) throws ServletException {
return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServlet(servletName);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public RequestDispatcher getRequestDispatcher(String s) {

@Override
public RequestDispatcher getNamedDispatcher(String s) {
throw new UnsupportedOperationException();
return new AwsProxyRequestDispatcher(s, true, containerHandler);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servletRegistration), servletContext));
}

putFilterChainCache(type, targetPath, chainHolder);
putFilterChainCache(type, targetPath, servlet, chainHolder);
// update total filter size
if (filtersSize != registrations.size()) {
filtersSize = registrations.size();
Expand All @@ -151,13 +151,16 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
* initialized with the cached list of {@link FilterHolder} objects
* @param type The dispatcher type for the incoming request
* @param targetPath The request path - this is extracted with the <code>getPath</code> method of the request object
* @param servlet Servlet to put at the end of the chain (optional).
* @param servlet The final servlet in the filter chain (if any)
* @return A populated FilterChainHolder
*/
private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet) {
TargetCacheKey key = new TargetCacheKey();
key.setDispatcherType(type);
key.setTargetPath(targetPath);
if (servlet != null) {
key.setServletName(servlet.getServletConfig().getServletName());
}

if (!filterCache.containsKey(key)) {
return null;
Expand All @@ -174,12 +177,16 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S
* method to retry this.
* @param type DispatcherType from the incoming request
* @param targetPath The target path in the API
* @param holder The FilterChainHolder object to save in the cache
* @param servlet The final servlet in the filter chain (if any)
* @param holder The FilterChainHolder object to save in the cache
*/
private void putFilterChainCache(final DispatcherType type, final String targetPath, final FilterChainHolder holder) {
private void putFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet, final FilterChainHolder holder) {
TargetCacheKey key = new TargetCacheKey();
key.setDispatcherType(type);
key.setTargetPath(targetPath);
if (servlet != null) {
key.setServletName(servlet.getServletConfig().getServletName());
}

// we couldn't compute the hash code because either the target path or dispatcher type were null
if (key.hashCode() == -1) {
Expand Down Expand Up @@ -256,6 +263,7 @@ protected static class TargetCacheKey {

private String targetPath;
private DispatcherType dispatcherType;
private String servletName;


//-------------------------------------------------------------
Expand Down Expand Up @@ -295,10 +303,15 @@ public int hashCode() {
}
hashString += ":" + hashDispatcher;

if (servletName != null) {
hashString += ":" + servletName;
}

return hashString.hashCode();
}



@Override
public boolean equals(Object key) {
if (key == null) {
Expand All @@ -324,6 +337,11 @@ void setTargetPath(String targetPath) {
void setDispatcherType(DispatcherType dispatcherType) {
this.dispatcherType = dispatcherType;
}
public void setServletName(String servletName) {
this.servletName = servletName;
}


}

@SuppressFBWarnings("URF_UNREAD_FIELD")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.amazonaws.serverless.proxy.internal.testutils;

import java.io.IOException;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class MockServlet extends HttpServlet {

private int serviceCalls = 0;

@Override
protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.service(req, resp);
serviceCalls++;
}

public int getServiceCalls() {
return serviceCalls;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.internal.testutils.MockServlet;
import com.amazonaws.services.lambda.runtime.Context;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
Expand All @@ -17,6 +18,8 @@
import static org.junit.jupiter.api.Assertions.*;

public class AwsFilterChainManagerTest {
private static final String SERVLET1_NAME = "Servlet 1";
private static final String SERVLET2_NAME = "Servlet 2";
private static final String REQUEST_CUSTOM_ATTRIBUTE_NAME = "X-Custom-Attribute";
private static final String REQUEST_CUSTOM_ATTRIBUTE_VALUE = "CustomAttrValue";

Expand All @@ -36,6 +39,10 @@ public static void setUp() {
reg2.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/second/*");
FilterRegistration.Dynamic reg3 = servletContext.addFilter("Filter3", new MockFilter());
reg3.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/third/fourth/*");
ServletRegistration.Dynamic firstServlet = servletContext.addServlet(SERVLET1_NAME, new MockServlet());
firstServlet.addMapping("/first/*");
ServletRegistration.Dynamic secondServlet = servletContext.addServlet(SERVLET2_NAME, new MockServlet());
secondServlet.addMapping("/second/*");

chainManager = new AwsFilterChainManager((AwsServletContext) servletContext);
}
Expand Down Expand Up @@ -88,6 +95,22 @@ void cacheKey_compare_differentDispatcher() {
assertNotEquals(cacheKey, secondCacheKey);
}

@Test
void cacheKey_compare_differentServlet() {
FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey();
cacheKey.setDispatcherType(DispatcherType.REQUEST);
cacheKey.setTargetPath("/first/path");
cacheKey.setServletName("Dispatcher servlet");

FilterChainManager.TargetCacheKey secondCacheKey = new FilterChainManager.TargetCacheKey();
secondCacheKey.setDispatcherType(DispatcherType.REQUEST);
secondCacheKey.setTargetPath("/first/path");
cacheKey.setServletName("Real servlet");

assertNotEquals(cacheKey.hashCode(), secondCacheKey.hashCode());
assertNotEquals(cacheKey, secondCacheKey);
}

@Test
void cacheKey_compare_additionalChars() {
FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey();
Expand Down Expand Up @@ -154,7 +177,7 @@ void filterChain_matchMultipleTimes_expectSameMatch() {
}

@Test
void filerChain_executeMultipleFilters_expectRunEachTime() {
void filterChain_executeMultipleFilters_expectRunEachTime() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null
);
Expand Down Expand Up @@ -204,6 +227,34 @@ void filerChain_executeMultipleFilters_expectRunEachTime() {
assertEquals(REQUEST_CUSTOM_ATTRIBUTE_VALUE, req2.getAttribute(REQUEST_CUSTOM_ATTRIBUTE_NAME));
}

@Test
void filterChain_multipleServlets_callsCorrectServlet() throws IOException, ServletException {
MockServlet servlet1 = (MockServlet) servletContext.getServlet(SERVLET1_NAME);
ServletConfig servlet1Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET1_NAME)).getServletConfig();
servlet1.init(servlet1Config);

MockServlet servlet2 = (MockServlet) servletContext.getServlet(SERVLET2_NAME);
ServletConfig servlet2Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET2_NAME)).getServletConfig();
servlet2.init(servlet2Config);

AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
new AwsProxyRequestBuilder("/", "GET").build(), lambdaContext, null
);
AwsHttpServletResponse resp = new AwsHttpServletResponse(req, new CountDownLatch(1));

FilterChainHolder servlet1filterChain = chainManager.getFilterChain(req, servlet1);
servlet1filterChain.doFilter(req, resp);

assertEquals(1, servlet1.getServiceCalls());
assertEquals(0, servlet2.getServiceCalls());

FilterChainHolder servlet2filterChain = chainManager.getFilterChain(req, servlet2);
servlet2filterChain.doFilter(req, resp);

assertEquals(1, servlet1.getServiceCalls());
assertEquals(1, servlet2.getServiceCalls());
}

@Test
void filterChain_getFilterChain_multipleFilters() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
import javax.servlet.http.HttpServletResponse;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.CountDownLatch;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -190,12 +185,7 @@ void unsupportedOperations_expectExceptions() {
} catch (UnsupportedOperationException e) {
exCount++;
}
try {
STATIC_CTX.getNamedDispatcher("1");
} catch (UnsupportedOperationException e) {
exCount++;
}
assertEquals(2, exCount);
assertEquals(1, exCount);

assertNull(STATIC_CTX.getServletRegistration("1"));
}
Expand Down Expand Up @@ -232,6 +222,12 @@ void addServlet_callsDefaultConstructor() throws ServletException {
assertEquals("", ((TestServlet)ctx.getServlet("srv1")).getId());
}

@Test
void getNamedDispatcher_returnsDispatcher() {
AwsServletContext ctx = new AwsServletContext(null);
assertNotNull(ctx.getNamedDispatcher("/hello"));
}

public static class TestServlet implements Servlet {
private String id;

Expand Down

0 comments on commit f1755de

Please sign in to comment.