diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java index 8091c2768136..fe6e6d74a9f6 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/WebSocketCoreSession.java @@ -54,7 +54,7 @@ public class WebSocketCoreSession implements CoreSession, Dumpable private final WebSocketComponents components; private final Behavior behavior; - private final WebSocketSessionState sessionState = new WebSocketSessionState(); + private final WebSocketSessionState sessionState; private final FrameHandler handler; private final Negotiated negotiated; private final Flusher flusher = new Flusher(this); @@ -76,6 +76,7 @@ public class WebSocketCoreSession implements CoreSession, Dumpable public WebSocketCoreSession(FrameHandler handler, Behavior behavior, Negotiated negotiated, WebSocketComponents components) { + this.sessionState = new WebSocketSessionState(behavior); this.classLoader = Thread.currentThread().getContextClassLoader(); this.components = components; this.handler = handler; @@ -219,16 +220,20 @@ public void onEof() if (LOG.isDebugEnabled()) LOG.debug("onEof() {}", this); - if (sessionState.onEof()) - closeConnection(sessionState.getCloseStatus(), Callback.NOOP); + WebSocketSessionState.EofResult result = sessionState.onEof(); + if (result.shutdownOutput()) + getConnection().getEndPoint().shutdownOutput(); + if (result.closeEndpoint()) + abort(); + if (result.notifyWebSocketClose()) + notifyWebSocketConnectionClose(sessionState.getCloseStatus(), Callback.NOOP); } - private void closeConnection(CloseStatus closeStatus, Callback callback) + private void notifyWebSocketConnectionClose(CloseStatus closeStatus, Callback callback) { if (LOG.isDebugEnabled()) LOG.debug("closeConnection() {} {}", closeStatus, this); - abort(); extensionStack.close(); // Forward Errors to Local WebSocket EndPoint @@ -335,9 +340,13 @@ private void processError(CloseStatus closeStatus, Callback callback) } else { - if (sessionState.onClosed(closeStatus)) + WebSocketSessionState.Result result = sessionState.onClosed(closeStatus); + if (result.closeEndpoint()) + abort(); + + if (result.notifyWebSocketClose()) { - closeConnection(closeStatus, callback); + notifyWebSocketConnectionClose(closeStatus, callback); } else { @@ -506,13 +515,39 @@ public void sendFrame(OutgoingEntry entry) if (LOG.isDebugEnabled()) LOG.debug("sendFrame({}, {}, {})", frame, callback, batch); - boolean closeConnection = sessionState.onOutgoingFrame(frame); - if (closeConnection) + WebSocketSessionState.Result result = sessionState.onOutgoingFrame(frame); + callback = Callback.from(callback, failure -> + { + if (failure != null) + { + CloseStatus closeStatus = new CloseStatus(CloseStatus.NO_CLOSE, failure); + WebSocketSessionState.Result closeResult = sessionState.onClosed(closeStatus); + if (closeResult.closeEndpoint()) + abort(); + if (closeResult.notifyWebSocketClose()) + notifyWebSocketConnectionClose(closeStatus, NOOP); + return; + } + + if (frame.getOpCode() == OpCode.CLOSE) + { + WebSocketSessionState.CloseResult closeResult = sessionState.onCloseFrameSent(); + if (closeResult.shutdownOutput()) + connection.getEndPoint().shutdownOutput(); + if (closeResult.closeEndpoint()) + abort(); + } + + if (result.closeEndpoint()) + abort(); + }); + + if (result.notifyWebSocketClose()) { Callback c = callback; Callback closeConnectionCallback = Callback.from( - () -> closeConnection(sessionState.getCloseStatus(), c), - t -> closeConnection(sessionState.getCloseStatus(), Callback.from(c, t))); + () -> notifyWebSocketConnectionClose(sessionState.getCloseStatus(), c), + t -> notifyWebSocketConnectionClose(sessionState.getCloseStatus(), Callback.from(c, t))); flusher.sendFrame(new OutgoingEntry.Builder(entry) .callback(closeConnectionCallback) @@ -533,8 +568,17 @@ public void sendFrame(OutgoingEntry entry) if (frame.getOpCode() == OpCode.CLOSE) { CloseStatus closeStatus = CloseStatus.getCloseStatus(frame); - if (closeStatus.isAbnormal() && sessionState.onClosed(closeStatus)) - closeConnection(closeStatus, Callback.from(callback, t)); + if (closeStatus.isAbnormal()) + { + WebSocketSessionState.Result result = sessionState.onClosed(closeStatus); + if (result.closeEndpoint()) + abort(); + + if (result.notifyWebSocketClose()) + notifyWebSocketConnectionClose(closeStatus, Callback.from(callback, t)); + else + callback.failed(t); + } else callback.failed(t); } @@ -656,7 +700,9 @@ public void onFrame(Frame frame, Callback callback) if (LOG.isDebugEnabled()) LOG.debug("receiveFrame({}, {}) - connectionState={}, handler={}", frame, callback, sessionState, handler); - boolean closeConnection = sessionState.onIncomingFrame(frame); + WebSocketSessionState.Result result = sessionState.onIncomingFrame(frame); + if (result.closeEndpoint()) + abort(); // Handle inbound frame if (frame.getOpCode() != OpCode.CLOSE) @@ -665,14 +711,15 @@ public void onFrame(Frame frame, Callback callback) return; } - // Handle inbound CLOSE + // Cancel demand to read to EOF, as we cannot receive any more frames after the CLOSE Frame. connection.cancelDemand(); - if (closeConnection) + if (result.notifyWebSocketClose()) { - closeCallback = Callback.from(() -> closeConnection(sessionState.getCloseStatus(), callback), t -> + closeCallback = Callback.from(() -> notifyWebSocketConnectionClose(sessionState.getCloseStatus(), callback), t -> { - sessionState.onError(t); - closeConnection(sessionState.getCloseStatus(), callback); + if (sessionState.onError(t)) + abort(); + notifyWebSocketConnectionClose(sessionState.getCloseStatus(), callback); }); } else diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java index 6fda274f3ab6..8c8150020151 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/FrameFlusher.java @@ -389,8 +389,6 @@ protected void onSuccess() for (FlusherEntry entry : _completedEntries) { - if (entry.getFrame().getOpCode() == OpCode.CLOSE && _behavior == Behavior.SERVER) - _endPoint.shutdownOutput(); notifyCallbackSuccess(entry.getCallback()); } _completedEntries.clear(); diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java index cc42e1dd6258..848d801bc010 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/internal/WebSocketSessionState.java @@ -17,6 +17,7 @@ import org.eclipse.jetty.util.TypeUtil; import org.eclipse.jetty.util.thread.AutoLock; +import org.eclipse.jetty.websocket.core.Behavior; import org.eclipse.jetty.websocket.core.CloseStatus; import org.eclipse.jetty.websocket.core.Frame; import org.eclipse.jetty.websocket.core.OpCode; @@ -27,7 +28,7 @@ */ public class WebSocketSessionState { - enum State + enum WebSocketState { CONNECTING, CONNECTED, @@ -37,31 +38,50 @@ enum State CLOSED } - private final AutoLock lock = new AutoLock(); - private State _sessionState = State.CONNECTING; + enum EndPointState + { + OPEN, + ISHUT, + OSHUT, + CLOSED + } + + public record Result(boolean notifyWebSocketClose, boolean closeEndpoint) {} + public record EofResult(boolean notifyWebSocketClose, boolean closeEndpoint, boolean shutdownOutput){} + public record CloseResult(boolean shutdownOutput, boolean closeEndpoint){} + + private final AutoLock _lock = new AutoLock(); + private final Behavior _behavior; + private WebSocketState _webSocketState = WebSocketState.CONNECTING; + private EndPointState _endPointState = EndPointState.OPEN; private byte _incomingContinuation = OpCode.UNDEFINED; private byte _outgoingContinuation = OpCode.UNDEFINED; CloseStatus _closeStatus = null; + public WebSocketSessionState(Behavior behavior) + { + _behavior = behavior; + } + public void onConnected() { - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { - if (_sessionState != State.CONNECTING) - throw new IllegalStateException(_sessionState.toString()); + if (_webSocketState != WebSocketState.CONNECTING) + throw new IllegalStateException(_webSocketState.toString()); - _sessionState = State.CONNECTED; + _webSocketState = WebSocketState.CONNECTED; } } public void onOpen() { - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { - switch (_sessionState) + switch (_webSocketState) { case CONNECTED: - _sessionState = State.OPEN; + _webSocketState = WebSocketState.OPEN; break; case OSHUT: @@ -70,57 +90,44 @@ public void onOpen() break; default: - throw new IllegalStateException(_sessionState.toString()); + throw new IllegalStateException(_webSocketState.toString()); } } } - private State getState() + private WebSocketState getState() { - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { - return _sessionState; + return _webSocketState; } } public boolean isClosed() { - return getState() == State.CLOSED; + return getState() == WebSocketState.CLOSED; } public boolean isInputOpen() { - State state = getState(); - return (state == State.OPEN || state == State.OSHUT); + WebSocketState state = getState(); + return (state == WebSocketState.OPEN || state == WebSocketState.OSHUT); } public boolean isOutputOpen() { - State state = getState(); - return (state == State.CONNECTED || state == State.OPEN || state == State.ISHUT); + WebSocketState state = getState(); + return (state == WebSocketState.CONNECTED || state == WebSocketState.OPEN || state == WebSocketState.ISHUT); } public CloseStatus getCloseStatus() { - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { return _closeStatus; } } - public boolean onClosed(CloseStatus closeStatus) - { - try (AutoLock l = lock.lock()) - { - if (_sessionState == State.CLOSED) - return false; - - _closeStatus = closeStatus; - _sessionState = State.CLOSED; - return true; - } - } - /** *
* If no error is set in the CloseStatus this will either, replace the current close status with @@ -129,19 +136,20 @@ public boolean onClosed(CloseStatus closeStatus) *
** This should only be called if there is an error directly before the call to - * {@code WebSocketCoreSession.closeConnection(CloseStatus, Callback)}. + * {@code WebSocketCoreSession#notifyWebSocketConnectionClose(CloseStatus, Callback)}. *
** This could occur if the FrameHandler throws an exception in onFrame after receiving a close frame reply, in this * case to notify onError we must set the cause in the closeStatus. *
* @param t the error which occurred. + * @return true if the endpoint should be closed. */ - public void onError(Throwable t) + public boolean onError(Throwable t) { - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { - if (_sessionState != State.CLOSED || _closeStatus == null) + if (_webSocketState != WebSocketState.CLOSED || _closeStatus == null) throw new IllegalArgumentException(); // Override any normal close status. @@ -151,34 +159,109 @@ public void onError(Throwable t) // Otherwise set the error if it wasn't already set to notify onError as well as onClose. if (_closeStatus.getCause() == null) _closeStatus = new CloseStatus(_closeStatus.getCode(), _closeStatus.getReason(), t); + + return lockedForceCloseEndpointState(); } } - public boolean onEof() + public Result onClosed(CloseStatus closeStatus) { - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { - switch (_sessionState) + boolean closeEndpoint = lockedForceCloseEndpointState(); + boolean notifyWebSocketClose = false; + if (_webSocketState != WebSocketState.CLOSED) { - case CLOSED: - case ISHUT: - return false; + _closeStatus = closeStatus; + _webSocketState = WebSocketState.CLOSED; + notifyWebSocketClose = true; + } - default: + return new Result(notifyWebSocketClose, closeEndpoint); + } + } + + /** + * Handle an EOF from the transport. + * @return a pair of booleans; + * The first indicates whether the websocket listeners should be notified of close. + * The second indicates whether the underlying endpoint should be closed. + */ + public EofResult onEof() + { + try (AutoLock l = _lock.lock()) + { + return switch (_webSocketState) + { + case CLOSED -> + { + boolean closeEndpoint = lockedForceCloseEndpointState(); + yield new EofResult(false, closeEndpoint, false); + } + case ISHUT -> + { + boolean closeEndpoint = false; + boolean shutdownOutput = false; + switch (_endPointState) + { + case OPEN -> _endPointState = EndPointState.ISHUT; + case CLOSED, ISHUT -> + { /* NOOP */ } + case OSHUT -> + { + // If this was a client it didn't shut down output when it sent the close frame because of RFC6455 7.1.1. + // So we should do the shutdown output before closing the endpoint. + shutdownOutput = _behavior == Behavior.CLIENT; + closeEndpoint = true; + _endPointState = EndPointState.CLOSED; + } + default -> throw new IllegalStateException(_endPointState.toString()); + } + yield new EofResult(false, closeEndpoint, shutdownOutput); + } + default -> + { if (_closeStatus == null || CloseStatus.isOrdinary(_closeStatus.getCode())) _closeStatus = new CloseStatus(CloseStatus.NO_CLOSE, "Session Closed", new ClosedChannelException()); - _sessionState = State.CLOSED; - return true; - } + _webSocketState = WebSocketState.CLOSED; + + boolean closeEndpoint = lockedForceCloseEndpointState(); + yield new EofResult(true, closeEndpoint, false); + } + }; } } - public boolean onOutgoingFrame(Frame frame) throws Exception + public CloseResult onCloseFrameSent() + { + try (AutoLock l = _lock.lock()) + { + return switch (_endPointState) + { + case OPEN -> + { + _endPointState = EndPointState.OSHUT; + // We only shut down output if we are a server because of RFC6455 7.1.1. + // When the client receives an EOF it will shut down its output. + yield new CloseResult(_behavior == Behavior.SERVER, false); + } + case ISHUT -> + { + // We have already read EOF so we can shut down output even if we're a client. + _endPointState = EndPointState.CLOSED; + yield new CloseResult(true, true); + } + case OSHUT, CLOSED -> new CloseResult(false, false); + }; + } + } + + public Result onOutgoingFrame(Frame frame) throws Exception { byte opcode = frame.getOpCode(); boolean fin = frame.isFin(); - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { if (!isOutputOpen()) throw new ClosedChannelException(); @@ -188,24 +271,25 @@ public boolean onOutgoingFrame(Frame frame) throws Exception _closeStatus = CloseStatus.getCloseStatus(frame); if (_closeStatus.isAbnormal()) { - _sessionState = State.CLOSED; - return true; + boolean closeEndpoint = lockedForceCloseEndpointState(); + _webSocketState = WebSocketState.CLOSED; + return new Result(true, closeEndpoint); } - switch (_sessionState) + return switch (_webSocketState) { - case CONNECTED: - case OPEN: - _sessionState = State.OSHUT; - return false; - - case ISHUT: - _sessionState = State.CLOSED; - return true; - - default: - throw new IllegalStateException(_sessionState.toString()); - } + case CONNECTED, OPEN -> + { + _webSocketState = WebSocketState.OSHUT; + yield new Result(false, false); + } + case ISHUT -> + { + _webSocketState = WebSocketState.CLOSED; + yield new Result(true, false); + } + default -> throw new IllegalStateException(_webSocketState.toString()); + }; } else if (frame.isDataFrame()) { @@ -213,15 +297,15 @@ else if (frame.isDataFrame()) } } - return false; + return new Result(false, false); } - public boolean onIncomingFrame(Frame frame) throws ProtocolException, ClosedChannelException + public Result onIncomingFrame(Frame frame) throws ProtocolException, ClosedChannelException { byte opcode = frame.getOpCode(); boolean fin = frame.isFin(); - try (AutoLock l = lock.lock()) + try (AutoLock l = _lock.lock()) { if (!isInputOpen()) throw new ClosedChannelException(); @@ -230,16 +314,19 @@ public boolean onIncomingFrame(Frame frame) throws ProtocolException, ClosedChan { _closeStatus = CloseStatus.getCloseStatus(frame); - switch (_sessionState) + switch (_webSocketState) { case OPEN: - _sessionState = State.ISHUT; - return false; + _webSocketState = WebSocketState.ISHUT; + return new Result(false, false); case OSHUT: - _sessionState = State.CLOSED; - return true; + // If we received abnormal status close, and we cannot send a response because we are OSHUT, + // so we should close underlying the connection. + boolean closeEndpoint = _closeStatus.isAbnormal() && lockedForceCloseEndpointState(); + _webSocketState = WebSocketState.CLOSED; + return new Result(true, closeEndpoint); default: - throw new IllegalStateException(_sessionState.toString()); + throw new IllegalStateException(_webSocketState.toString()); } } else if (frame.isDataFrame()) @@ -248,19 +335,32 @@ else if (frame.isDataFrame()) } } - return false; + return new Result(false, false); } @Override public String toString() { return String.format("%s@%x{%s,i=%s,o=%s,c=%s}", TypeUtil.toShortName(getClass()), hashCode(), - _sessionState, + _webSocketState, OpCode.name(_incomingContinuation), OpCode.name(_outgoingContinuation), _closeStatus); } + private boolean lockedForceCloseEndpointState() + { + assert _lock.isHeldByCurrentThread(); + + boolean closeEndpoint = false; + if (_endPointState != EndPointState.CLOSED) + { + _endPointState = EndPointState.CLOSED; + closeEndpoint = true; + } + return closeEndpoint; + } + private static byte checkDataSequence(byte opcode, boolean fin, byte lastOpCode) throws ProtocolException { switch (opcode) diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java index 0c890ba4faa7..09a0b6b3db5d 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-common/src/main/java/org/eclipse/jetty/websocket/common/JettyWebSocketFrameHandler.java @@ -18,7 +18,6 @@ import java.lang.invoke.MethodType; import java.lang.reflect.InvocationTargetException; import java.nio.ByteBuffer; -import java.nio.channels.ClosedChannelException; import java.util.concurrent.atomic.AtomicBoolean; import org.eclipse.jetty.util.BufferUtil; @@ -270,7 +269,7 @@ private void notifyOnClose(CloseStatus closeStatus, Callback callback) // Make sure onClose is only notified once. if (!closeNotified.compareAndSet(false, true)) { - callback.failed(new ClosedChannelException()); + callback.succeeded(); return; } diff --git a/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/TestListenerEndpoint.java b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/TestListenerEndpoint.java new file mode 100644 index 000000000000..ca241660b0df --- /dev/null +++ b/jetty-core/jetty-websocket/jetty-websocket-jetty-tests/src/test/java/org/eclipse/jetty/websocket/tests/TestListenerEndpoint.java @@ -0,0 +1,113 @@ +// +// ======================================================================== +// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.websocket.tests; + +import java.nio.ByteBuffer; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; + +import org.eclipse.jetty.util.BlockingArrayQueue; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.TypeUtil; +import org.eclipse.jetty.websocket.api.Callback; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.StatusCode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TestListenerEndpoint implements Session.Listener.AutoDemanding +{ + private static final Logger LOG = LoggerFactory.getLogger(TestListenerEndpoint.class); + + public final BlockingQueue