diff --git a/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java b/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java index 7619c3bb6f5dd..57eaa2bab363b 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java +++ b/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java @@ -43,6 +43,7 @@ import javax.servlet.http.HttpServletResponse; import java.io.IOException; +import java.util.Enumeration; import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -128,6 +129,20 @@ protected void parseURI(String requestURI, HttpServletRequest request, HttpServl OutputBufferId bufferId = null; long token = 0; + if (request != null) { + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + String headerValue = request.getHeader(headerName); + if (headerName.contains("\r") || headerName.contains("\n")) { + throw new IllegalArgumentException(format("Invalid header name: %s", headerName)); + } + if (headerValue.contains("\r") || headerValue.contains("\n")) { + throw new IllegalArgumentException(format("Invalid header value: %s", headerValue)); + } + } + } + int previousIndex = -1; for (int part = 0; part < 8; part++) { int nextIndex = requestURI.indexOf('/', previousIndex + 1); diff --git a/presto-tests/src/test/java/com/facebook/presto/server/TestAsyncPageTransportServlet.java b/presto-tests/src/test/java/com/facebook/presto/server/TestAsyncPageTransportServlet.java index d67dd943a7479..aa567502bd7fc 100644 --- a/presto-tests/src/test/java/com/facebook/presto/server/TestAsyncPageTransportServlet.java +++ b/presto-tests/src/test/java/com/facebook/presto/server/TestAsyncPageTransportServlet.java @@ -15,6 +15,10 @@ import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import javax.servlet.http.HttpServletRequest; @@ -23,6 +27,7 @@ import java.io.IOException; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; import static org.testng.Assert.fail; @Test(singleThreaded = true) @@ -34,6 +39,7 @@ class TestServlet TaskId taskId; OutputBufferId bufferId; String requestURI; + HttpServletRequest request; long token; void parse(String uri) throws IOException @@ -41,6 +47,11 @@ void parse(String uri) throws IOException parseURI(uri, null, null); } + void parse(String uri, HttpServletRequest request) throws IOException + { + parseURI(uri, request, null); + } + @Override protected void processRequest( String requestURI, TaskId taskId, OutputBufferId bufferId, long token, @@ -50,6 +61,7 @@ protected void processRequest( this.taskId = taskId; this.bufferId = bufferId; this.token = token; + this.request = request; } @Override @@ -80,6 +92,27 @@ public void testParsing() assertEquals(789, servlet.token); } + @DataProvider(name = "testSanitizationProvider") + public Object[][] testSanitizationProvider() + { + return new Object[][] { + {"ke\ny", "value"}, + {"key", "valu\ne"}, + {"ke\ry", "value"}, + {"key", "valu\re"}}; + } + + @Test(dataProvider = "testSanitizationProvider") + public void testSanitization(String key, String value) + { + ListMultimap headers = ImmutableListMultimap.of(key, value); + HttpServletRequest request = new MockHttpServletRequest(headers, "", ImmutableMap.of()); + TestServlet servlet = new TestServlet(); + assertThrows( + IllegalArgumentException.class, + () -> { servlet.parse("/v1/task/async/0.1.2.3.4/results/456/789", request); }); + } + @Test (expectedExceptions = { IllegalArgumentException.class }) public void testParseTooFewElements() {