diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java index d17270d98..b76fd216e 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsHttpServletRequest.java @@ -76,6 +76,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest { static final String PROTOCOL_HEADER_NAME = "X-Forwarded-Proto"; static final String HOST_HEADER_NAME = "Host"; static final String PORT_HEADER_NAME = "X-Forwarded-Port"; + static final String CLIENT_IP_HEADER = "X-Forwarded-For"; //------------------------------------------------------------- diff --git a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java index 3e92d43dc..17d04a57c 100644 --- a/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java +++ b/aws-serverless-java-container-core/src/main/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequest.java @@ -21,13 +21,15 @@ import com.amazonaws.serverless.proxy.model.RequestSource; import com.amazonaws.services.lambda.runtime.Context; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import jakarta.servlet.*; -import jakarta.servlet.http.*; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.HttpUpgradeHandler; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.SecurityContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.IOException; @@ -435,12 +437,22 @@ public String getRemoteAddr() { if (request.getRequestContext() == null || request.getRequestContext().getIdentity() == null) { return "127.0.0.1"; } + if (request.getRequestContext().getElb() != null) { + return request.getHeaders().get(CLIENT_IP_HEADER); + } return request.getRequestContext().getIdentity().getSourceIp(); } @Override public String getRemoteHost() { + if (Objects.nonNull(request.getRequestContext().getElb())) { + String hostHeader = request.getHeaders().get(HttpHeaders.HOST); + + // the host header has the form host:port, so we split the string to get the host part + return Arrays.asList(hostHeader.split(":")).get(0); + } + return request.getMultiValueHeaders().getFirst(HttpHeaders.HOST); } @@ -471,6 +483,12 @@ public RequestDispatcher getRequestDispatcher(String s) { @Override public int getRemotePort() { + if (Objects.nonNull(request.getRequestContext().getElb())) { + String portHeader = request.getHeaders().get(PORT_HEADER_NAME); + if (Objects.nonNull(portHeader)) { + return Integer.parseInt(portHeader); + } + } return 0; } diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java index 76507508b..45c38843e 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/servlet/AwsProxyHttpServletRequestTest.java @@ -651,6 +651,18 @@ void serverName_albHostHeader_returnsHostHeader() { assertEquals("testapi.us-east-1.elb.amazonaws.com", serverName); } + @Test + void getRemoteHost_albHostHeader_returnsHostHeader() { + initAwsProxyHttpServletRequestTest("ALB"); + AwsProxyRequest proxyReq = new AwsProxyRequestBuilder("/test", "GET") + .alb().build(); + proxyReq.getHeaders().put(HttpHeaders.HOST, "testapi.us-east-1.elb.amazonaws.com"); + HttpServletRequest servletRequest = new AwsProxyHttpServletRequest(proxyReq, null, null); + + String host = servletRequest.getRemoteHost(); + assertEquals("testapi.us-east-1.elb.amazonaws.com", host); + } + private AwsProxyRequestBuilder getRequestWithHeaders() { return new AwsProxyRequestBuilder("/hello", "GET") .header(CUSTOM_HEADER_KEY, CUSTOM_HEADER_VALUE) diff --git a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java index 9df66a891..40f0ae7ad 100644 --- a/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java +++ b/aws-serverless-java-container-core/src/test/java/com/amazonaws/serverless/proxy/internal/testutils/AwsProxyRequestBuilder.java @@ -70,7 +70,8 @@ public AwsProxyRequestBuilder(AwsProxyRequest req) { public AwsProxyRequestBuilder(String path, String httpMethod) { this.request = new AwsProxyRequest(); - this.request.setMultiValueHeaders(new Headers()); // avoid NPE + this.request.setMultiValueHeaders(new Headers());// avoid NPE + this.request.setHeaders(new SingleValueHeaders()); this.request.setHttpMethod(httpMethod); this.request.setPath(path); this.request.setMultiValueQueryStringParameters(new MultiValuedTreeMap<>());