Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,22 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ReadListener;
import jakarta.servlet.RequestDispatcher;
import jakarta.servlet.ServletConnection;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletRequestAttributeEvent;
import jakarta.servlet.ServletRequestAttributeListener;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.WriteListener;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletMapping;
import jakarta.servlet.http.HttpServletRequest;
Expand All @@ -63,10 +60,7 @@
import jakarta.servlet.http.HttpUpgradeHandler;
import jakarta.servlet.http.Part;
import jakarta.servlet.http.PushBuilder;
import jakarta.servlet.http.WebConnection;
import org.eclipse.jetty.ee11.servlet.ServletContextHandler.ServletRequestInfo;
import org.eclipse.jetty.ee11.servlet.util.ServletInputStreamWrapper;
import org.eclipse.jetty.ee11.servlet.util.ServletOutputStreamWrapper;
import org.eclipse.jetty.http.BadMessageException;
import org.eclipse.jetty.http.CookieCompliance;
import org.eclipse.jetty.http.HttpCookie;
Expand All @@ -81,7 +75,6 @@
import org.eclipse.jetty.http.MimeTypes;
import org.eclipse.jetty.http.SetCookieParser;
import org.eclipse.jetty.http.pathmap.MatchedResource;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.QuietException;
import org.eclipse.jetty.security.AuthenticationState;
import org.eclipse.jetty.security.UserIdentity;
Expand Down Expand Up @@ -757,201 +750,20 @@ public Part getPart(String name) throws IOException, ServletException
public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass) throws IOException, ServletException
{
Response response = _servletContextRequest.getServletContextResponse();
if (response.getStatus() != HttpStatus.SWITCHING_PROTOCOLS_101)
throw new IllegalStateException("Response status should be 101");
if (response.getHeaders().get("Upgrade") == null)
throw new IllegalStateException("Missing Upgrade header");
if (!"Upgrade".equalsIgnoreCase(response.getHeaders().get("Connection")))
throw new IllegalStateException("Invalid Connection header");
if (response.isCommitted())
throw new IllegalStateException("Cannot upgrade committed response");
throw new ServletException("Cannot upgrade committed response");
if (_servletChannel.getConnectionMetaData().getHttpVersion() != HttpVersion.HTTP_1_1)
throw new IllegalStateException("Only requests over HTTP/1.1 can be upgraded");
throw new ServletException("Only requests over HTTP/1.1 can be upgraded");

CompletableFuture<Void> outputStreamComplete = new CompletableFuture<>();
CompletableFuture<Void> inputStreamComplete = new CompletableFuture<>();
ServletOutputStream outputStream = new ServletOutputStreamWrapper(_servletContextRequest.getHttpOutput())
{
@Override
public void write(int b) throws IOException
{
try
{
super.write(b);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void write(byte[] b) throws IOException
{
try
{
super.write(b);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void write(byte[] b, int off, int len) throws IOException
{
try
{
super.write(b, off, len);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void close() throws IOException
{
try
{
super.close();
outputStreamComplete.complete(null);
}
catch (Throwable t)
{
outputStreamComplete.completeExceptionally(t);
throw t;
}
}
response.setStatus(HttpStatus.SWITCHING_PROTOCOLS_101);
response.getHeaders().put(HttpHeader.CONNECTION, HttpHeaderValue.UPGRADE);

@Override
public void setWriteListener(WriteListener writeListener)
{
super.setWriteListener(new WriteListener()
{
@Override
public void onWritePossible() throws IOException
{
writeListener.onWritePossible();
}

@Override
public void onError(Throwable t)
{
writeListener.onError(t);
outputStreamComplete.completeExceptionally(t);
}
});
}
};
ServletInputStream inputStream = new ServletInputStreamWrapper(_servletContextRequest.getHttpInput())
{
@Override
public int read() throws IOException
{
try
{
int read = super.read();
if (read == -1)
inputStreamComplete.complete(null);
return read;
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public int read(byte[] b) throws IOException
{
try
{
int read = super.read(b);
if (read == -1)
inputStreamComplete.complete(null);
return read;
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public int read(byte[] b, int off, int len) throws IOException
{
try
{
int read = super.read(b, off, len);
if (read == -1)
inputStreamComplete.complete(null);
return read;
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void close() throws IOException
{
try
{
super.close();
inputStreamComplete.complete(null);
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void setReadListener(ReadListener readListener)
{
super.setReadListener(new ReadListener()
{
@Override
public void onDataAvailable() throws IOException
{
readListener.onDataAvailable();
}

@Override
public void onAllDataRead() throws IOException
{
try
{
readListener.onAllDataRead();
inputStreamComplete.complete(null);
}
catch (Throwable t)
{
inputStreamComplete.completeExceptionally(t);
throw t;
}
}

@Override
public void onError(Throwable t)
{
readListener.onError(t);
inputStreamComplete.completeExceptionally(t);
}
});
}
};
// Use the first protocol from the request Upgrade header.
Optional<String> upgradeProtocol = _servletContextRequest.getHeaders()
.getCSV(HttpHeader.UPGRADE, false).stream().findFirst();
if (upgradeProtocol.isEmpty() || upgradeProtocol.get().isBlank())
throw new ServletException("Missing Upgrade header");
response.getHeaders().put(HttpHeader.UPGRADE, upgradeProtocol.get());

T upgradeHandler;
try
Expand All @@ -963,48 +775,7 @@ public void onError(Throwable t)
throw new ServletException("Unable to instantiate handler class", e);
}

Connection connection = _servletContextRequest.getConnectionMetaData().getConnection();
if (connection instanceof Connection.Tunnel upgradeableConnection)
{
outputStream.flush(); // commit the 101 response
upgradeableConnection.startTunnel();
}
else
{
LOG.warn("Unexpected connection type {}", connection);
throw new IllegalStateException();
}

AsyncContext asyncContext = forceStartAsync(); // force the servlet in async mode
CompletableFuture.allOf(inputStreamComplete, outputStreamComplete).whenComplete((result, failure) ->
{
upgradeHandler.destroy();
asyncContext.complete();
});

WebConnection webConnection = new WebConnection()
{
@Override
public void close() throws Exception
{
IO.close(inputStream);
IO.close(outputStream);
}

@Override
public ServletInputStream getInputStream()
{
return inputStream;
}

@Override
public ServletOutputStream getOutputStream()
{
return outputStream;
}
};

upgradeHandler.init(webConnection);
_servletChannel.upgrade(upgradeHandler);
return upgradeHandler;
}

Expand Down Expand Up @@ -1661,28 +1432,28 @@ public AsyncContext startAsync() throws IllegalStateException
return forceStartAsync();
}

private AsyncContext forceStartAsync()
@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException
{
ServletChannelState state = getServletRequestInfo().getState();
if (_async == null)
_async = new AsyncContextState(state);
if (!isAsyncSupported())
throw new IllegalStateException("Async Not Supported");
return forceStartAsync(servletRequest, servletResponse);
}

// We must remember the request as last dispatched by the container so that we can use its uri for
// possible subsequent dispatch
public AsyncContext forceStartAsync()
{
ServletRequestInfo servletRequestInfo = getServletRequestInfo();
AsyncContextEvent event = new AsyncContextEvent(getServletRequestInfo().getServletContext(), _async, state, servletRequestInfo.getServletChannel().getServletContextRequest().getServletApiRequest(), servletRequestInfo.getServletChannel().getServletContextResponse().getServletApiResponse());
state.startAsync(event);
return _async;
ServletRequest servletRequest = servletRequestInfo.getServletChannel().getServletContextRequest().getServletApiRequest();
ServletResponse servletResponse = servletRequestInfo.getServletChannel().getServletContextResponse().getServletApiResponse();
return forceStartAsync(servletRequest, servletResponse);
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException
public AsyncContext forceStartAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException
{
if (!isAsyncSupported())
throw new IllegalStateException("Async Not Supported");
ServletChannelState state = getServletRequestInfo().getState();
if (_async == null)
_async = new AsyncContextState(state);

//We must remember the request and response passed in for use in a possible subsequent dispatch
AsyncContextEvent event = new AsyncContextEvent(getServletRequestInfo().getServletContext(), _async, state, servletRequest, servletResponse);
state.startAsync(event);
Expand Down
Loading