From 7726d0502beb7c249876818558430dd64bcda5b4 Mon Sep 17 00:00:00 2001 From: Lachlan Roberts Date: Wed, 17 Sep 2025 18:47:12 +1000 Subject: [PATCH] Issue #13335 - fixes for ServletUpgrade in EE11 Signed-off-by: Lachlan Roberts --- .../jetty/ee11/servlet/ServletApiRequest.java | 279 ++--------------- .../jetty/ee11/servlet/ServletChannel.java | 289 +++++++++++++++++- .../ee11/servlet/ServletChannelState.java | 23 +- .../ee11/servlet/ServletUpgradeTest.java | 186 ++++++++--- 4 files changed, 471 insertions(+), 306 deletions(-) diff --git a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java index 0f8f45c3e25d..d27ec7eee914 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java +++ b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletApiRequest.java @@ -36,25 +36,22 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import jakarta.servlet.AsyncContext; import jakarta.servlet.DispatcherType; -import jakarta.servlet.ReadListener; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletConnection; import jakarta.servlet.ServletContext; import jakarta.servlet.ServletException; import jakarta.servlet.ServletInputStream; -import jakarta.servlet.ServletOutputStream; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletRequestAttributeEvent; import jakarta.servlet.ServletRequestAttributeListener; import jakarta.servlet.ServletResponse; -import jakarta.servlet.WriteListener; import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletMapping; import jakarta.servlet.http.HttpServletRequest; @@ -63,10 +60,7 @@ import jakarta.servlet.http.HttpUpgradeHandler; import jakarta.servlet.http.Part; import jakarta.servlet.http.PushBuilder; -import jakarta.servlet.http.WebConnection; import org.eclipse.jetty.ee11.servlet.ServletContextHandler.ServletRequestInfo; -import org.eclipse.jetty.ee11.servlet.util.ServletInputStreamWrapper; -import org.eclipse.jetty.ee11.servlet.util.ServletOutputStreamWrapper; import org.eclipse.jetty.http.BadMessageException; import org.eclipse.jetty.http.CookieCompliance; import org.eclipse.jetty.http.HttpCookie; @@ -81,7 +75,6 @@ import org.eclipse.jetty.http.MimeTypes; import org.eclipse.jetty.http.SetCookieParser; import org.eclipse.jetty.http.pathmap.MatchedResource; -import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.QuietException; import org.eclipse.jetty.security.AuthenticationState; import org.eclipse.jetty.security.UserIdentity; @@ -757,201 +750,20 @@ public Part getPart(String name) throws IOException, ServletException public T upgrade(Class handlerClass) throws IOException, ServletException { Response response = _servletContextRequest.getServletContextResponse(); - if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS_101) - throw new IllegalStateException("Response status should be 101"); - if (response.getHeaders().get("Upgrade") == null) - throw new IllegalStateException("Missing Upgrade header"); - if (!"Upgrade".equalsIgnoreCase(response.getHeaders().get("Connection"))) - throw new IllegalStateException("Invalid Connection header"); if (response.isCommitted()) - throw new IllegalStateException("Cannot upgrade committed response"); + throw new ServletException("Cannot upgrade committed response"); if (_servletChannel.getConnectionMetaData().getHttpVersion() != HttpVersion.HTTP_1_1) - throw new IllegalStateException("Only requests over HTTP/1.1 can be upgraded"); + throw new ServletException("Only requests over HTTP/1.1 can be upgraded"); - CompletableFuture outputStreamComplete = new CompletableFuture<>(); - CompletableFuture inputStreamComplete = new CompletableFuture<>(); - ServletOutputStream outputStream = new ServletOutputStreamWrapper(_servletContextRequest.getHttpOutput()) - { - @Override - public void write(int b) throws IOException - { - try - { - super.write(b); - } - catch (Throwable t) - { - outputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public void write(byte[] b) throws IOException - { - try - { - super.write(b); - } - catch (Throwable t) - { - outputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public void write(byte[] b, int off, int len) throws IOException - { - try - { - super.write(b, off, len); - } - catch (Throwable t) - { - outputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public void close() throws IOException - { - try - { - super.close(); - outputStreamComplete.complete(null); - } - catch (Throwable t) - { - outputStreamComplete.completeExceptionally(t); - throw t; - } - } + response.setStatus(HttpStatus.SWITCHING_PROTOCOLS_101); + response.getHeaders().put(HttpHeader.CONNECTION, HttpHeaderValue.UPGRADE); - @Override - public void setWriteListener(WriteListener writeListener) - { - super.setWriteListener(new WriteListener() - { - @Override - public void onWritePossible() throws IOException - { - writeListener.onWritePossible(); - } - - @Override - public void onError(Throwable t) - { - writeListener.onError(t); - outputStreamComplete.completeExceptionally(t); - } - }); - } - }; - ServletInputStream inputStream = new ServletInputStreamWrapper(_servletContextRequest.getHttpInput()) - { - @Override - public int read() throws IOException - { - try - { - int read = super.read(); - if (read == -1) - inputStreamComplete.complete(null); - return read; - } - catch (Throwable t) - { - inputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public int read(byte[] b) throws IOException - { - try - { - int read = super.read(b); - if (read == -1) - inputStreamComplete.complete(null); - return read; - } - catch (Throwable t) - { - inputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public int read(byte[] b, int off, int len) throws IOException - { - try - { - int read = super.read(b, off, len); - if (read == -1) - inputStreamComplete.complete(null); - return read; - } - catch (Throwable t) - { - inputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public void close() throws IOException - { - try - { - super.close(); - inputStreamComplete.complete(null); - } - catch (Throwable t) - { - inputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public void setReadListener(ReadListener readListener) - { - super.setReadListener(new ReadListener() - { - @Override - public void onDataAvailable() throws IOException - { - readListener.onDataAvailable(); - } - - @Override - public void onAllDataRead() throws IOException - { - try - { - readListener.onAllDataRead(); - inputStreamComplete.complete(null); - } - catch (Throwable t) - { - inputStreamComplete.completeExceptionally(t); - throw t; - } - } - - @Override - public void onError(Throwable t) - { - readListener.onError(t); - inputStreamComplete.completeExceptionally(t); - } - }); - } - }; + // Use the first protocol from the request Upgrade header. + Optional upgradeProtocol = _servletContextRequest.getHeaders() + .getCSV(HttpHeader.UPGRADE, false).stream().findFirst(); + if (upgradeProtocol.isEmpty() || upgradeProtocol.get().isBlank()) + throw new ServletException("Missing Upgrade header"); + response.getHeaders().put(HttpHeader.UPGRADE, upgradeProtocol.get()); T upgradeHandler; try @@ -963,48 +775,7 @@ public void onError(Throwable t) throw new ServletException("Unable to instantiate handler class", e); } - Connection connection = _servletContextRequest.getConnectionMetaData().getConnection(); - if (connection instanceof Connection.Tunnel upgradeableConnection) - { - outputStream.flush(); // commit the 101 response - upgradeableConnection.startTunnel(); - } - else - { - LOG.warn("Unexpected connection type {}", connection); - throw new IllegalStateException(); - } - - AsyncContext asyncContext = forceStartAsync(); // force the servlet in async mode - CompletableFuture.allOf(inputStreamComplete, outputStreamComplete).whenComplete((result, failure) -> - { - upgradeHandler.destroy(); - asyncContext.complete(); - }); - - WebConnection webConnection = new WebConnection() - { - @Override - public void close() throws Exception - { - IO.close(inputStream); - IO.close(outputStream); - } - - @Override - public ServletInputStream getInputStream() - { - return inputStream; - } - - @Override - public ServletOutputStream getOutputStream() - { - return outputStream; - } - }; - - upgradeHandler.init(webConnection); + _servletChannel.upgrade(upgradeHandler); return upgradeHandler; } @@ -1661,28 +1432,28 @@ public AsyncContext startAsync() throws IllegalStateException return forceStartAsync(); } - private AsyncContext forceStartAsync() + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { - ServletChannelState state = getServletRequestInfo().getState(); - if (_async == null) - _async = new AsyncContextState(state); + if (!isAsyncSupported()) + throw new IllegalStateException("Async Not Supported"); + return forceStartAsync(servletRequest, servletResponse); + } - // We must remember the request as last dispatched by the container so that we can use its uri for - // possible subsequent dispatch + public AsyncContext forceStartAsync() + { ServletRequestInfo servletRequestInfo = getServletRequestInfo(); - AsyncContextEvent event = new AsyncContextEvent(getServletRequestInfo().getServletContext(), _async, state, servletRequestInfo.getServletChannel().getServletContextRequest().getServletApiRequest(), servletRequestInfo.getServletChannel().getServletContextResponse().getServletApiResponse()); - state.startAsync(event); - return _async; + ServletRequest servletRequest = servletRequestInfo.getServletChannel().getServletContextRequest().getServletApiRequest(); + ServletResponse servletResponse = servletRequestInfo.getServletChannel().getServletContextResponse().getServletApiResponse(); + return forceStartAsync(servletRequest, servletResponse); } - @Override - public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException + public AsyncContext forceStartAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { - if (!isAsyncSupported()) - throw new IllegalStateException("Async Not Supported"); ServletChannelState state = getServletRequestInfo().getState(); if (_async == null) _async = new AsyncContextState(state); + //We must remember the request and response passed in for use in a possible subsequent dispatch AsyncContextEvent event = new AsyncContextEvent(getServletRequestInfo().getServletContext(), _async, state, servletRequest, servletResponse); state.startAsync(event); diff --git a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannel.java b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannel.java index 894a3e7dc45b..d7a20dbb8bb1 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannel.java +++ b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannel.java @@ -16,18 +16,28 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.util.Objects; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ReadListener; import jakarta.servlet.RequestDispatcher; import jakarta.servlet.ServletContext; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.WriteListener; +import jakarta.servlet.http.HttpUpgradeHandler; +import jakarta.servlet.http.WebConnection; import org.eclipse.jetty.ee11.servlet.ServletChannelState.Action; +import org.eclipse.jetty.ee11.servlet.util.ServletInputStreamWrapper; +import org.eclipse.jetty.ee11.servlet.util.ServletOutputStreamWrapper; import org.eclipse.jetty.http.BadMessageException; import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.http.HttpStatus; import org.eclipse.jetty.http.HttpURI; +import org.eclipse.jetty.http.HttpVersion; import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.EndPoint; import org.eclipse.jetty.io.QuietException; @@ -43,6 +53,7 @@ import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.ExceptionUtil; import org.eclipse.jetty.util.HostPort; +import org.eclipse.jetty.util.IO; import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.TypeUtil; import org.eclipse.jetty.util.URIUtil; @@ -85,6 +96,7 @@ public class ServletChannel private Response _response; private Callback _callback; private boolean _completeAttempted; + private HttpUpgradeHandler _upgradeHandler; public ServletChannel(ServletContextHandler servletContextHandler, Request request) { @@ -557,8 +569,21 @@ public void handle() break; } + case UPGRADE: + { + doUpgrade(_upgradeHandler); + break; + } + case COMPLETE: { + if (_upgradeHandler != null) + { + AsyncContextEvent asyncContextEvent = _state.getAsyncContextEvent(); + asyncContextEvent.getAsyncContext().complete(); + _upgradeHandler.destroy(); + } + ServletContextResponse response = getServletContextResponse(); if (!response.isCommitted()) { @@ -609,6 +634,268 @@ public void handle() LOG.debug("!handle {} {}", action, this); } + public void upgrade(HttpUpgradeHandler upgradeHandler) + { + _upgradeHandler = upgradeHandler; + _state.upgrade(); + } + + private void doUpgrade(HttpUpgradeHandler upgradeHandler) throws IOException + { + { + Response response = _servletContextRequest.getServletContextResponse(); + if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS_101) + throw new IllegalStateException("Response status should be 101"); + if (response.getHeaders().get("Upgrade") == null) + throw new IllegalStateException("Missing Upgrade header"); + if (!"Upgrade".equalsIgnoreCase(response.getHeaders().get("Connection"))) + throw new IllegalStateException("Invalid Connection header"); + if (response.isCommitted()) + throw new IllegalStateException("Cannot upgrade committed response"); + if (getConnectionMetaData().getHttpVersion() != HttpVersion.HTTP_1_1) + throw new IllegalStateException("Only requests over HTTP/1.1 can be upgraded"); + + CompletableFuture outputStreamComplete = new CompletableFuture<>(); + CompletableFuture inputStreamComplete = new CompletableFuture<>(); + ServletOutputStream outputStream = new ServletOutputStreamWrapper(_servletContextRequest.getHttpOutput()) + { + @Override + public void write(int b) throws IOException + { + try + { + super.write(b); + } + catch (Throwable t) + { + outputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void write(byte[] b) throws IOException + { + try + { + super.write(b); + } + catch (Throwable t) + { + outputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + try + { + super.write(b, off, len); + } + catch (Throwable t) + { + outputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void flush() throws IOException + { + try + { + super.flush(); + } + catch (Throwable t) + { + outputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void close() throws IOException + { + try + { + super.close(); + outputStreamComplete.complete(null); + } + catch (Throwable t) + { + outputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void setWriteListener(WriteListener writeListener) + { + super.setWriteListener(new WriteListener() + { + @Override + public void onWritePossible() throws IOException + { + writeListener.onWritePossible(); + } + + @Override + public void onError(Throwable t) + { + writeListener.onError(t); + outputStreamComplete.completeExceptionally(t); + } + }); + } + }; + ServletInputStream inputStream = new ServletInputStreamWrapper(_servletContextRequest.getHttpInput()) + { + @Override + public int read() throws IOException + { + try + { + int read = super.read(); + if (read == -1) + inputStreamComplete.complete(null); + return read; + } + catch (Throwable t) + { + inputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public int read(byte[] b) throws IOException + { + try + { + int read = super.read(b); + if (read == -1) + inputStreamComplete.complete(null); + return read; + } + catch (Throwable t) + { + inputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public int read(byte[] b, int off, int len) throws IOException + { + try + { + int read = super.read(b, off, len); + if (read == -1) + inputStreamComplete.complete(null); + return read; + } + catch (Throwable t) + { + inputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void close() throws IOException + { + try + { + super.close(); + inputStreamComplete.complete(null); + } + catch (Throwable t) + { + inputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void setReadListener(ReadListener readListener) + { + super.setReadListener(new ReadListener() + { + @Override + public void onDataAvailable() throws IOException + { + readListener.onDataAvailable(); + } + + @Override + public void onAllDataRead() throws IOException + { + try + { + readListener.onAllDataRead(); + inputStreamComplete.complete(null); + } + catch (Throwable t) + { + inputStreamComplete.completeExceptionally(t); + throw t; + } + } + + @Override + public void onError(Throwable t) + { + readListener.onError(t); + inputStreamComplete.completeExceptionally(t); + } + }); + } + }; + + AsyncContext asyncContext = _servletContextRequest.getServletApiRequest().forceStartAsync(); + CompletableFuture.allOf(inputStreamComplete, outputStreamComplete).whenComplete((result, failure) -> + asyncContext.complete()); + + Connection connection = _servletContextRequest.getConnectionMetaData().getConnection(); + if (connection instanceof Connection.Tunnel upgradeableConnection) + { + outputStream.flush(); // commit the 101 response + upgradeableConnection.startTunnel(); + } + else + { + LOG.warn("Unexpected connection type {}", connection); + throw new IllegalStateException(); + } + + WebConnection webConnection = new WebConnection() + { + @Override + public void close() + { + IO.close(inputStream); + IO.close(outputStream); + } + + @Override + public ServletInputStream getInputStream() + { + return inputStream; + } + + @Override + public ServletOutputStream getOutputStream() + { + return outputStream; + } + }; + + upgradeHandler.init(webConnection); + } + } + private void reopen() { _servletContextRequest.getServletContextResponse().getHttpOutput().reopen(); diff --git a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java index 01603f1cb80a..9af367ee4a61 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java +++ b/jetty-ee11/jetty-ee11-servlet/src/main/java/org/eclipse/jetty/ee11/servlet/ServletChannelState.java @@ -93,6 +93,7 @@ private enum RequestState DISPATCH, // AsyncContext.dispatch() has been called EXPIRE, // AsyncContext timeout has happened EXPIRING, // AsyncListeners are being called + UPGRADING, // Request is being upgraded COMPLETE, // AsyncContext.complete() has been called COMPLETING, // Request is being closed (maybe asynchronously) COMPLETED // Response is completed @@ -170,6 +171,7 @@ public enum Action ASYNC_TIMEOUT, // call asyncContext onTimeout WRITE_CALLBACK, // handle an IO write callback READ_CALLBACK, // handle an IO read callback + UPGRADE, // Complete the response by closing output COMPLETE, // Complete the response by closing output TERMINATED, // No further actions WAIT, // Wait for further events @@ -557,6 +559,11 @@ private Action nextAction(boolean handling) _sendError = false; return Action.SEND_ERROR; + case UPGRADING: + if (handling) + throw new IllegalStateException(getStatusStringLocked()); + return Action.UPGRADE; + case COMPLETE: _requestState = RequestState.COMPLETING; return Action.COMPLETE; @@ -583,7 +590,7 @@ public void startAsync(AsyncContextEvent event) { if (LOG.isDebugEnabled()) LOG.debug("startAsync {}", toStringLocked()); - if (_state != State.HANDLING || (_requestState != RequestState.BLOCKING && _requestState != RequestState.ERRORING)) + if (_state != State.HANDLING || (_requestState != RequestState.BLOCKING && _requestState != RequestState.ERRORING && _requestState != RequestState.UPGRADING)) throw new IllegalStateException(this.getStatusStringLocked()); if (!_failureListener) @@ -1204,23 +1211,11 @@ protected void recycle() public void upgrade() { - cancelTimeout(); try (AutoLock ignored = lock()) { if (LOG.isDebugEnabled()) LOG.debug("upgrade {}", toStringLocked()); - - if (_state != State.IDLE) - throw new IllegalStateException(getStatusStringLocked()); - if (_inputState != InputState.IDLE) - throw new IllegalStateException(getStatusStringLocked()); - _asyncListeners = null; - _state = State.UPGRADED; - _requestState = RequestState.BLOCKING; - _initial = true; - _asyncWritePossible = false; - _timeoutMs = DEFAULT_TIMEOUT; - _event = null; + _requestState = RequestState.UPGRADING; } } diff --git a/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/ServletUpgradeTest.java b/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/ServletUpgradeTest.java index 3fb9eee21c9e..f43ae992cad1 100644 --- a/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/ServletUpgradeTest.java +++ b/jetty-ee11/jetty-ee11-servlet/src/test/java/org/eclipse/jetty/ee11/servlet/ServletUpgradeTest.java @@ -17,8 +17,11 @@ import java.io.InputStream; import java.io.OutputStream; import java.net.Socket; +import java.net.SocketException; +import java.time.Duration; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import jakarta.servlet.ReadListener; @@ -30,20 +33,23 @@ import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpUpgradeHandler; import jakarta.servlet.http.WebConnection; +import org.eclipse.jetty.http.HttpHeader; +import org.eclipse.jetty.http.HttpStatus; +import org.eclipse.jetty.io.EofException; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; -import org.eclipse.jetty.util.Utf8StringBuilder; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.awaitility.Awaitility.await; import static org.eclipse.jetty.util.StringUtil.CRLF; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.startsWith; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class ServletUpgradeTest @@ -52,13 +58,9 @@ public class ServletUpgradeTest private Server server; private int port; - private static CountDownLatch destroyLatch; - @BeforeEach - public void setUp() throws Exception + public void setUp(HttpServlet servlet) throws Exception { - destroyLatch = new CountDownLatch(1); - server = new Server(); ServerConnector connector = new ServerConnector(server); @@ -66,7 +68,7 @@ public void setUp() throws Exception ServletContextHandler contextHandler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS); contextHandler.setContextPath("/"); - contextHandler.addServlet(new ServletHolder(new TestServlet()), "/TestServlet"); + contextHandler.addServlet(new ServletHolder(servlet), "/"); server.setHandler(contextHandler); @@ -83,6 +85,24 @@ public void tearDown() throws Exception @Test public void upgradeTest() throws Exception { + CompletableFuture futureUpgradeHandler = new CompletableFuture<>(); + setUp(new HttpServlet() + { + public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + TestHttpUpgradeHandler handler = request.upgrade(TestHttpUpgradeHandler.class); + futureUpgradeHandler.complete(handler); + + // The call to upgrade() automatically sets the required response status and headers. + assertThat(response.getStatus(), equalTo(HttpStatus.SWITCHING_PROTOCOLS_101)); + assertThat(response.getHeader(HttpHeader.CONNECTION.asString()), equalTo(HttpHeader.UPGRADE.asString())); + assertThat(response.getHeader(HttpHeader.UPGRADE.asString()), equalTo("YES")); + + // Assert that init has not been called yet. + assertThat(handler.initLatch.getCount(), equalTo(1L)); + } + }); + Socket socket = new Socket("localhost", port); socket.setSoTimeout(0); InputStream input = socket.getInputStream(); @@ -98,13 +118,12 @@ public void upgradeTest() throws Exception writeChunk(output, "Hello"); writeChunk(output, "World"); output.flush(); - socket.shutdownOutput(); + StringBuffer sb = new StringBuffer(); CompletableFuture futureContent = new CompletableFuture<>(); new Thread(() -> { LOG.info("Consuming the response from the server"); - Utf8StringBuilder sb = new Utf8StringBuilder(); try { while (true) @@ -112,56 +131,136 @@ public void upgradeTest() throws Exception int read = input.read(); if (read == -1) break; - sb.append((byte)read); + sb.append((char)read); } - futureContent.complete(sb.toCompleteString()); + futureContent.complete(sb.toString()); } catch (Throwable t) { LOG.warn("failed with content: " + sb, t); futureContent.completeExceptionally(t); } - }).start(); - String content = futureContent.get(5, TimeUnit.SECONDS); - String expectedContent = """ + // Wait until we get the echoed content. + await().atMost(Duration.ofSeconds(5)).pollDelay(Duration.ofMillis(200)) + .until(() -> + { + System.err.println("testing: " + sb); + return sb.toString().contains("HelloWorld"); + }); + + // The destroy latch is only counted down after the connection is closed. + TestHttpUpgradeHandler handler = futureUpgradeHandler.get(5, TimeUnit.SECONDS); + assertThat(handler.destroyLatch.getCount(), equalTo(1L)); + socket.shutdownOutput(); + assertTrue(handler.destroyLatch.await(5, TimeUnit.SECONDS)); + + String fullContent = futureContent.get(5, TimeUnit.SECONDS); + assertThat(fullContent, containsString("HTTP/1.1 101 Switching Protocols")); + assertThat(fullContent, containsString("Connection: Upgrade")); + assertThat(fullContent, containsString("Upgrade: YES")); + assertThat(fullContent, containsString(""" TCKHttpUpgradeHandler.init\r =onDataAvailable\r HelloWorld\r =onAllDataRead\r - """; - assertThat(content, startsWith("HTTP/1.1 101 Switching Protocols")); - assertThat(content, endsWith(expectedContent)); + """)); - input.close(); - output.close(); socket.close(); - assertTrue(destroyLatch.await(5, TimeUnit.SECONDS)); + assertTrue(handler.destroyLatch.await(5, TimeUnit.SECONDS)); } - private static class TestServlet extends HttpServlet + @Test + public void testEarlyEof() throws Exception { - public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + CompletableFuture futureUpgradeHandler = new CompletableFuture<>(); + setUp(new HttpServlet() { - if (request.getHeader("Upgrade") != null) + public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - response.setStatus(101); - response.setHeader("Upgrade", "YES"); - response.setHeader("Connection", "Upgrade"); TestHttpUpgradeHandler handler = request.upgrade(TestHttpUpgradeHandler.class); - assertThat(handler, instanceOf(TestHttpUpgradeHandler.class)); + futureUpgradeHandler.complete(handler); + + // The call to upgrade() automatically sets the required response status and headers. + assertThat(response.getStatus(), equalTo(HttpStatus.SWITCHING_PROTOCOLS_101)); + assertThat(response.getHeader(HttpHeader.CONNECTION.asString()), equalTo(HttpHeader.UPGRADE.asString())); + assertThat(response.getHeader(HttpHeader.UPGRADE.asString()), equalTo("YES")); + + // Assert that init has not been called yet. + assertThat(handler.initLatch.getCount(), equalTo(1L)); + } + }); + + Socket socket = new Socket("localhost", port); + socket.setSoTimeout(0); + InputStream input = socket.getInputStream(); + OutputStream output = socket.getOutputStream(); + + String request = "POST /TestServlet HTTP/1.1" + CRLF + + "Host: localhost:" + port + CRLF + + "Upgrade: YES" + CRLF + + "Connection: Upgrade" + CRLF + + CRLF; + + output.write(request.getBytes()); + writeChunk(output, "Hello"); + writeChunk(output, "World"); + output.flush(); + + StringBuffer sb = new StringBuffer(); + CompletableFuture futureContent = new CompletableFuture<>(); + new Thread(() -> + { + LOG.info("Consuming the response from the server"); + try + { + while (true) + { + int read = input.read(); + if (read == -1) + break; + sb.append((char)read); + } + futureContent.complete(sb.toString()); } - else + catch (Throwable t) { - response.getWriter().println("No upgrade"); - response.getWriter().println("End of Test"); + futureContent.completeExceptionally(t); } - } + }).start(); + + // Wait until we get the echoed content. + await().atMost(Duration.ofSeconds(5)).pollDelay(Duration.ofMillis(200)).until(() -> sb.toString().contains("HelloWorld")); + String content = sb.toString(); + assertThat(content, containsString("HTTP/1.1 101 Switching Protocols")); + assertThat(content, containsString("Connection: Upgrade")); + assertThat(content, containsString("Upgrade: YES")); + assertThat(content, containsString(""" + TCKHttpUpgradeHandler.init\r + =onDataAvailable\r + HelloWorld""")); + + // The HttpUpgradeHandler.destroy() should still be called in case of an error. + socket.setSoLinger(true, 0); + socket.close(); + TestHttpUpgradeHandler handler = futureUpgradeHandler.get(5, TimeUnit.SECONDS); + assertTrue(handler.destroyLatch.await(5, TimeUnit.SECONDS)); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> futureContent.get(5, TimeUnit.SECONDS)); + assertThat(exception.getCause(), instanceOf(SocketException.class)); + assertThat(exception.getCause().getMessage(), containsString("Socket closed")); + + Throwable throwable = handler.errorFuture.get(5, TimeUnit.SECONDS); + assertThat(throwable, instanceOf(EofException.class)); } public static class TestHttpUpgradeHandler implements HttpUpgradeHandler { + public CountDownLatch initLatch = new CountDownLatch(1); + public CountDownLatch destroyLatch = new CountDownLatch(1); + public CompletableFuture errorFuture = new CompletableFuture<>(); + public TestHttpUpgradeHandler() { } @@ -179,7 +278,7 @@ public void init(WebConnection wc) { ServletInputStream input = wc.getInputStream(); ServletOutputStream output = wc.getOutputStream(); - TestReadListener readListener = new TestReadListener(input, output); + TestReadListener readListener = new TestReadListener(this, input, output); input.setReadListener(readListener); output.println("TCKHttpUpgradeHandler.init"); output.flush(); @@ -188,17 +287,28 @@ public void init(WebConnection wc) { throw new RuntimeException(ex); } + finally + { + initLatch.countDown(); + } + } + + public void onError(Throwable t) + { + errorFuture.complete(t); } } private static class TestReadListener implements ReadListener { + private final TestHttpUpgradeHandler upgradeHandler; private final ServletInputStream input; private final ServletOutputStream output; private boolean outputOnDataAvailable = false; - TestReadListener(ServletInputStream in, ServletOutputStream out) + TestReadListener(TestHttpUpgradeHandler upgradeHandler, ServletInputStream in, ServletOutputStream out) { + this.upgradeHandler = upgradeHandler; input = in; output = out; } @@ -213,6 +323,7 @@ public void onAllDataRead() } catch (Exception ex) { + upgradeHandler.onError(ex); throw new IllegalStateException(ex); } } @@ -241,6 +352,7 @@ public void onDataAvailable() } catch (Exception ex) { + upgradeHandler.onError(ex); throw new IllegalStateException(ex); } } @@ -248,7 +360,7 @@ public void onDataAvailable() @Override public void onError(final Throwable t) { - LOG.error("TestReadListener error", t); + upgradeHandler.onError(t); } }