From 96874eceaaa2e08fa4950a03a790c1b98a6216f2 Mon Sep 17 00:00:00 2001 From: gregw Date: Wed, 26 Jul 2023 19:18:51 +0200 Subject: [PATCH] Fix #10155 Include mixed writers and input streams Fix #10155 Include mixed writers and input streams --- .../jetty/ee10/servlet/Dispatcher.java | 110 ++++++++++++- .../jetty/ee10/servlet/DispatcherTest.java | 148 +++++++++++++++--- .../test/resources/contextResources/test.txt | 2 +- 3 files changed, 233 insertions(+), 27 deletions(-) diff --git a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/Dispatcher.java b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/Dispatcher.java index 9062e4e8fc82..c73f1961e5b0 100644 --- a/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/Dispatcher.java +++ b/jetty-ee10/jetty-ee10-servlet/src/main/java/org/eclipse/jetty/ee10/servlet/Dispatcher.java @@ -14,10 +14,14 @@ package org.eclipse.jetty.ee10.servlet; import java.io.IOException; +import java.io.OutputStream; +import java.io.PrintWriter; +import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -27,6 +31,7 @@ import jakarta.servlet.ServletOutputStream; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletResponse; +import jakarta.servlet.WriteListener; import jakarta.servlet.http.HttpServletMapping; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; @@ -35,6 +40,7 @@ import org.eclipse.jetty.ee10.servlet.util.ServletOutputStreamWrapper; import org.eclipse.jetty.http.HttpURI; import org.eclipse.jetty.http.pathmap.MatchedResource; +import org.eclipse.jetty.io.WriterOutputStream; import org.eclipse.jetty.util.MultiMap; import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.UrlEncoded; @@ -136,12 +142,14 @@ public void include(ServletRequest request, ServletResponse response) throws Ser HttpServletResponse httpResponse = (response instanceof HttpServletResponse) ? (HttpServletResponse)response : new ServletResponseHttpWrapper(response); ServletContextResponse servletContextResponse = ServletContextResponse.getServletContextResponse(response); + IncludeResponse includeResponse = new IncludeResponse(httpResponse); try { - _mappedServlet.handle(_servletHandler, _decodedPathInContext, new IncludeRequest(httpRequest), new IncludeResponse(httpResponse)); + _mappedServlet.handle(_servletHandler, _decodedPathInContext, new IncludeRequest(httpRequest), includeResponse); } finally { + includeResponse.flushIfWriting(); servletContextResponse.included(); } } @@ -423,23 +431,107 @@ public String getQueryString() private static class IncludeResponse extends HttpServletResponseWrapper { public static final String JETTY_INCLUDE_HEADER_PREFIX = "org.eclipse.jetty.server.include."; + ServletOutputStream _servletOutputStream; + PrintWriter _printWriter; public IncludeResponse(HttpServletResponse response) { super(response); } + public void flushIfWriting() + { + if (_printWriter != null) + _printWriter.flush(); + } + @Override public ServletOutputStream getOutputStream() throws IOException { - return new ServletOutputStreamWrapper(getResponse().getOutputStream()) + if (_printWriter != null) + throw new IllegalStateException("getWriter() called"); + if (_servletOutputStream == null) { - @Override - public void close() + try { - // NOOP for include. + _servletOutputStream = new ServletOutputStreamWrapper(getResponse().getOutputStream()) + { + @Override + public void close() + { + // NOOP for include. + } + }; } - }; + catch (IllegalStateException ise) + { + OutputStream os = new WriterOutputStream(getResponse().getWriter(), getResponse().getCharacterEncoding()); + _servletOutputStream = new ServletOutputStream() + { + @Override + public boolean isReady() + { + return true; + } + + @Override + public void setWriteListener(WriteListener writeListener) + { + throw new UnsupportedOperationException(); + } + + @Override + public void write(int b) throws IOException + { + os.write(b); + } + + @Override + public void write(byte[] b) throws IOException + { + os.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + os.write(b, off, len); + } + + @Override + public void flush() throws IOException + { + os.flush(); + } + + @Override + public void close() + { + // NOOP for include. + } + }; + } + } + return _servletOutputStream; + } + + @Override + public PrintWriter getWriter() throws IOException + { + if (_servletOutputStream != null) + throw new IllegalStateException("getOutputStream called"); + if (_printWriter == null) + { + try + { + _printWriter = super.getWriter(); + } + catch (IllegalStateException ise) + { + _printWriter = new PrintWriter(super.getOutputStream(), false, Charset.forName(super.getCharacterEncoding())); + } + } + return _printWriter; } @Override @@ -448,6 +540,12 @@ public void setCharacterEncoding(String charset) // NOOP for include. } + @Override + public void setLocale(Locale loc) + { + // NOOP for include. + } + @Override public void setContentLength(int len) { diff --git a/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/DispatcherTest.java b/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/DispatcherTest.java index 70c6f36952dc..a76ad17385f9 100644 --- a/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/DispatcherTest.java +++ b/jetty-ee10/jetty-ee10-servlet/src/test/java/org/eclipse/jetty/ee10/servlet/DispatcherTest.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.EnumSet; import java.util.List; +import java.util.stream.Stream; import jakarta.servlet.AsyncContext; import jakarta.servlet.DispatcherType; @@ -61,6 +62,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -380,10 +384,81 @@ public void testInclude() throws Exception HTTP/1.1 200 OK\r specialSetHeader: specialSetHeader\r specialAddHeader: specialAddHeader\r - Content-Length: 7\r + Content-Length: 20\r + Connection: close\r + \r + Include: + INCLUDE--- + """; + + assertEquals(expected, responses); + } + + public static Stream includeTests() + { + return Stream.of( + Arguments.of(false, false), + Arguments.of(false, true), + Arguments.of(true, false), + Arguments.of(true, true) + ); + } + + @ParameterizedTest + @MethodSource("includeTests") + public void testIncludeOutputStreamWriter(boolean includeWriter, boolean helloWriter) throws Exception + { + _contextHandler.addServlet(new ServletHolder(new IncludeServlet(includeWriter)), "/IncludeServlet/*"); + _contextHandler.addServlet(new ServletHolder(new HelloServlet(helloWriter)), "/Hello"); + + //test include, along with special extension to include that allows headers to + //be set during an include + String responses = _connector.getResponse(""" + GET /context/IncludeServlet?do=hello HTTP/1.1\r + Host: local\r + Connection: close\r + \r + """); + + responses = responses.replaceFirst("Content-Length: .*\r\n", ""); + + String expected = """ + HTTP/1.1 200 OK\r + Connection: close\r + \r + Include: + Hello + --- + """; + + assertEquals(expected, responses); + } + + @Test + public void testIncludeWriterOutputStream() throws Exception + { + _contextHandler.addServlet(IncludeServlet.class, "/IncludeServlet/*"); + _contextHandler.addServlet(AssertIncludeServlet.class, "/AssertIncludeServlet/*"); + + //test include, along with special extension to include that allows headers to + //be set during an include + String responses = _connector.getResponse(""" + GET /context/IncludeServlet?do=assertinclude&do=more&test=1&headers=true HTTP/1.1\r + Host: local\r + Connection: close\r + \r + """); + + String expected = """ + HTTP/1.1 200 OK\r + specialSetHeader: specialSetHeader\r + specialAddHeader: specialAddHeader\r + Content-Length: 20\r Connection: close\r \r - INCLUDE"""; + Include: + INCLUDE--- + """; assertEquals(expected, responses); } @@ -404,17 +479,18 @@ public void testIncludeStatic() throws Exception String expected = """ HTTP/1.1 200 OK\r - Content-Length: 26\r + Content-Length: 31\r Connection: close\r \r Include: - Test 2 to too two"""; + Test 2 to too two + --- + """; assertEquals(expected, responses); } @Test - @Disabled("Bug #10155 - response misses the Content-Length header") public void testIncludeStaticWithWriter() throws Exception { _contextHandler.addServlet(new ServletHolder(new IncludeServlet(true)), "/IncludeServlet/*"); @@ -430,11 +506,12 @@ public void testIncludeStaticWithWriter() throws Exception String expected = """ HTTP/1.1 200 OK\r - Content-Length: 26\r Connection: close\r \r Include: - Test 2 to too two"""; + Test 2 to too two + --- + """; assertEquals(expected, responses); } @@ -460,10 +537,11 @@ public void testForwardStatic() throws Exception Last-Modified: xxx\r Content-Type: text/plain\r Accept-Ranges: bytes\r - Content-Length: 17\r + Content-Length: 18\r Connection: close\r \r - Test 2 to too two"""; + Test 2 to too two + """; assertEquals(expected, responses); @@ -558,10 +636,12 @@ public void testForwardThenInclude() throws Exception String expected = """ HTTP/1.1 200 OK\r - Content-Length: 7\r + Content-Length: 20\r Connection: close\r \r - INCLUDE"""; + Include: + INCLUDE--- + """; assertEquals(expected, rawResponse); } @@ -582,10 +662,11 @@ public void testIncludeThenForward() throws Exception String expected = """ HTTP/1.1 200 OK\r - Content-Length: 7\r + Content-Length: 11\r Connection: close\r \r - FORWARD"""; + FORWARD--- + """; assertEquals(expected, rawResponse); } @@ -664,7 +745,6 @@ public void testServletInclude() throws Exception String expected = """ HTTP/1.1 200 OK\r - Content-Length: 11\r \r Roger That!"""; @@ -1078,6 +1158,11 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) t RequestDispatcher dispatcher = null; boolean headers = Boolean.parseBoolean(request.getParameter("headers")); + if (useWriter) + response.getWriter().println("Include:"); + else + response.getOutputStream().write("Include:\n".getBytes(StandardCharsets.US_ASCII)); + if (request.getParameter("do").equals("forward")) dispatcher = getServletContext().getRequestDispatcher("/ForwardServlet/forwardpath?do=assertincludeforward"); else if (request.getParameter("do").equals("assertforwardinclude")) @@ -1085,15 +1170,38 @@ else if (request.getParameter("do").equals("assertforwardinclude")) else if (request.getParameter("do").equals("assertinclude")) dispatcher = getServletContext().getRequestDispatcher("/AssertIncludeServlet?do=end&do=the&headers=" + headers); else if (request.getParameter("do").equals("static")) - { - if (useWriter) - response.getWriter().println("Include:"); - else - response.getOutputStream().write("Include:\n".getBytes(StandardCharsets.US_ASCII)); dispatcher = getServletContext().getRequestDispatcher("/test.txt"); - } + else if (request.getParameter("do").equals("hello")) + dispatcher = getServletContext().getRequestDispatcher("/Hello"); + assert dispatcher != null; + dispatcher.include(request, response); + + if (useWriter) + response.getWriter().println("---"); + else + response.getOutputStream().write("---\n".getBytes(StandardCharsets.US_ASCII)); + } + } + + public static class HelloServlet extends HttpServlet implements Servlet + { + // The logic linked to this field be deleted and the writer always used once #10155 is fixed. + private final boolean useWriter; + + public HelloServlet(boolean useWriter) + { + this.useWriter = useWriter; + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + if (useWriter) + response.getWriter().println("Hello"); + else + response.getOutputStream().write("Hello\n".getBytes(StandardCharsets.US_ASCII)); } } diff --git a/jetty-ee10/jetty-ee10-servlet/src/test/resources/contextResources/test.txt b/jetty-ee10/jetty-ee10-servlet/src/test/resources/contextResources/test.txt index cf3582b72034..4518b58d2aaa 100644 --- a/jetty-ee10/jetty-ee10-servlet/src/test/resources/contextResources/test.txt +++ b/jetty-ee10/jetty-ee10-servlet/src/test/resources/contextResources/test.txt @@ -1 +1 @@ -Test 2 to too two \ No newline at end of file +Test 2 to too two