Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Dispatch to the given servlet when using getNamedDispatcher #502

Merged
merged 1 commit into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void forward(ServletRequest servletRequest, ServletResponse servletRespon
}

if (isNamedDispatcher) {
lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet((HttpServletRequest)servletRequest));
lambdaContainerHandler.doFilter((HttpServletRequest) servletRequest, (HttpServletResponse) servletResponse, getServlet(dispatchTo));
return;
}

Expand Down Expand Up @@ -148,4 +148,9 @@ void setRequestPath(ServletRequest req, final String destinationPath) {
private Servlet getServlet(HttpServletRequest req) {
return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServletForPath(req.getPathInfo());
}

private Servlet getServlet(String servletName) throws ServletException {
return ((AwsServletContext)lambdaContainerHandler.getServletContext()).getServlet(servletName);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ public RequestDispatcher getRequestDispatcher(String s) {

@Override
public RequestDispatcher getNamedDispatcher(String s) {
throw new UnsupportedOperationException();
return new AwsProxyRequestDispatcher(s, true, containerHandler);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
chainHolder.addFilter(new FilterHolder(new ServletExecutionFilter(servletRegistration), servletContext));
}

putFilterChainCache(type, targetPath, chainHolder);
putFilterChainCache(type, targetPath, servlet, chainHolder);
// update total filter size
if (filtersSize != registrations.size()) {
filtersSize = registrations.size();
Expand All @@ -151,13 +151,16 @@ FilterChainHolder getFilterChain(final HttpServletRequest request, Servlet servl
* initialized with the cached list of {@link FilterHolder} objects
* @param type The dispatcher type for the incoming request
* @param targetPath The request path - this is extracted with the <code>getPath</code> method of the request object
* @param servlet Servlet to put at the end of the chain (optional).
* @param servlet The final servlet in the filter chain (if any)
* @return A populated FilterChainHolder
*/
private FilterChainHolder getFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet) {
TargetCacheKey key = new TargetCacheKey();
key.setDispatcherType(type);
key.setTargetPath(targetPath);
if (servlet != null) {
key.setServletName(servlet.getServletConfig().getServletName());
}

if (!filterCache.containsKey(key)) {
return null;
Expand All @@ -174,12 +177,16 @@ private FilterChainHolder getFilterChainCache(final DispatcherType type, final S
* method to retry this.
* @param type DispatcherType from the incoming request
* @param targetPath The target path in the API
* @param holder The FilterChainHolder object to save in the cache
* @param servlet The final servlet in the filter chain (if any)
* @param holder The FilterChainHolder object to save in the cache
*/
private void putFilterChainCache(final DispatcherType type, final String targetPath, final FilterChainHolder holder) {
private void putFilterChainCache(final DispatcherType type, final String targetPath, Servlet servlet, final FilterChainHolder holder) {
TargetCacheKey key = new TargetCacheKey();
key.setDispatcherType(type);
key.setTargetPath(targetPath);
if (servlet != null) {
key.setServletName(servlet.getServletConfig().getServletName());
}

// we couldn't compute the hash code because either the target path or dispatcher type were null
if (key.hashCode() == -1) {
Expand Down Expand Up @@ -256,6 +263,7 @@ protected static class TargetCacheKey {

private String targetPath;
private DispatcherType dispatcherType;
private String servletName;


//-------------------------------------------------------------
Expand Down Expand Up @@ -295,10 +303,15 @@ public int hashCode() {
}
hashString += ":" + hashDispatcher;

if (servletName != null) {
hashString += ":" + servletName;
}

return hashString.hashCode();
}



@Override
public boolean equals(Object key) {
if (key == null) {
Expand All @@ -324,6 +337,11 @@ void setTargetPath(String targetPath) {
void setDispatcherType(DispatcherType dispatcherType) {
this.dispatcherType = dispatcherType;
}
public void setServletName(String servletName) {
this.servletName = servletName;
}


}

@SuppressFBWarnings("URF_UNREAD_FIELD")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.amazonaws.serverless.proxy.internal.testutils;

import java.io.IOException;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class MockServlet extends HttpServlet {

private int serviceCalls = 0;

@Override
protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.service(req, resp);
serviceCalls++;
}

public int getServiceCalls() {
return serviceCalls;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.amazonaws.serverless.proxy.internal.testutils.AwsProxyRequestBuilder;
import com.amazonaws.serverless.proxy.internal.testutils.MockLambdaContext;
import com.amazonaws.serverless.proxy.internal.testutils.MockServlet;
import com.amazonaws.services.lambda.runtime.Context;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
Expand All @@ -17,6 +18,8 @@
import static org.junit.jupiter.api.Assertions.*;

public class AwsFilterChainManagerTest {
private static final String SERVLET1_NAME = "Servlet 1";
private static final String SERVLET2_NAME = "Servlet 2";
private static final String REQUEST_CUSTOM_ATTRIBUTE_NAME = "X-Custom-Attribute";
private static final String REQUEST_CUSTOM_ATTRIBUTE_VALUE = "CustomAttrValue";

Expand All @@ -36,6 +39,10 @@ public static void setUp() {
reg2.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/second/*");
FilterRegistration.Dynamic reg3 = servletContext.addFilter("Filter3", new MockFilter());
reg3.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/third/fourth/*");
ServletRegistration.Dynamic firstServlet = servletContext.addServlet(SERVLET1_NAME, new MockServlet());
firstServlet.addMapping("/first/*");
ServletRegistration.Dynamic secondServlet = servletContext.addServlet(SERVLET2_NAME, new MockServlet());
secondServlet.addMapping("/second/*");

chainManager = new AwsFilterChainManager((AwsServletContext) servletContext);
}
Expand Down Expand Up @@ -88,6 +95,22 @@ void cacheKey_compare_differentDispatcher() {
assertNotEquals(cacheKey, secondCacheKey);
}

@Test
void cacheKey_compare_differentServlet() {
FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey();
cacheKey.setDispatcherType(DispatcherType.REQUEST);
cacheKey.setTargetPath("/first/path");
cacheKey.setServletName("Dispatcher servlet");

FilterChainManager.TargetCacheKey secondCacheKey = new FilterChainManager.TargetCacheKey();
secondCacheKey.setDispatcherType(DispatcherType.REQUEST);
secondCacheKey.setTargetPath("/first/path");
cacheKey.setServletName("Real servlet");

assertNotEquals(cacheKey.hashCode(), secondCacheKey.hashCode());
assertNotEquals(cacheKey, secondCacheKey);
}

@Test
void cacheKey_compare_additionalChars() {
FilterChainManager.TargetCacheKey cacheKey = new FilterChainManager.TargetCacheKey();
Expand Down Expand Up @@ -154,7 +177,7 @@ void filterChain_matchMultipleTimes_expectSameMatch() {
}

@Test
void filerChain_executeMultipleFilters_expectRunEachTime() {
void filterChain_executeMultipleFilters_expectRunEachTime() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
new AwsProxyRequestBuilder("/first/second", "GET").build(), lambdaContext, null
);
Expand Down Expand Up @@ -204,6 +227,34 @@ void filerChain_executeMultipleFilters_expectRunEachTime() {
assertEquals(REQUEST_CUSTOM_ATTRIBUTE_VALUE, req2.getAttribute(REQUEST_CUSTOM_ATTRIBUTE_NAME));
}

@Test
void filterChain_multipleServlets_callsCorrectServlet() throws IOException, ServletException {
MockServlet servlet1 = (MockServlet) servletContext.getServlet(SERVLET1_NAME);
ServletConfig servlet1Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET1_NAME)).getServletConfig();
servlet1.init(servlet1Config);

MockServlet servlet2 = (MockServlet) servletContext.getServlet(SERVLET2_NAME);
ServletConfig servlet2Config = ((AwsServletRegistration) servletContext.getServletRegistration(SERVLET2_NAME)).getServletConfig();
servlet2.init(servlet2Config);

AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
new AwsProxyRequestBuilder("/", "GET").build(), lambdaContext, null
);
AwsHttpServletResponse resp = new AwsHttpServletResponse(req, new CountDownLatch(1));

FilterChainHolder servlet1filterChain = chainManager.getFilterChain(req, servlet1);
servlet1filterChain.doFilter(req, resp);

assertEquals(1, servlet1.getServiceCalls());
assertEquals(0, servlet2.getServiceCalls());

FilterChainHolder servlet2filterChain = chainManager.getFilterChain(req, servlet2);
servlet2filterChain.doFilter(req, resp);

assertEquals(1, servlet1.getServiceCalls());
assertEquals(1, servlet2.getServiceCalls());
}

@Test
void filterChain_getFilterChain_multipleFilters() {
AwsProxyHttpServletRequest req = new AwsProxyHttpServletRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
import javax.servlet.http.HttpServletResponse;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.concurrent.CountDownLatch;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -190,12 +185,7 @@ void unsupportedOperations_expectExceptions() {
} catch (UnsupportedOperationException e) {
exCount++;
}
try {
STATIC_CTX.getNamedDispatcher("1");
} catch (UnsupportedOperationException e) {
exCount++;
}
assertEquals(2, exCount);
assertEquals(1, exCount);

assertNull(STATIC_CTX.getServletRegistration("1"));
}
Expand Down Expand Up @@ -232,6 +222,12 @@ void addServlet_callsDefaultConstructor() throws ServletException {
assertEquals("", ((TestServlet)ctx.getServlet("srv1")).getId());
}

@Test
void getNamedDispatcher_returnsDispatcher() {
AwsServletContext ctx = new AwsServletContext(null);
assertNotNull(ctx.getNamedDispatcher("/hello"));
}

public static class TestServlet implements Servlet {
private String id;

Expand Down