Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class WebSocketCoreSession implements CoreSession, Dumpable

private final WebSocketComponents components;
private final Behavior behavior;
private final WebSocketSessionState sessionState = new WebSocketSessionState();
private final WebSocketSessionState sessionState;
private final FrameHandler handler;
private final Negotiated negotiated;
private final Flusher flusher = new Flusher(this);
Expand All @@ -76,6 +76,7 @@ public class WebSocketCoreSession implements CoreSession, Dumpable

public WebSocketCoreSession(FrameHandler handler, Behavior behavior, Negotiated negotiated, WebSocketComponents components)
{
this.sessionState = new WebSocketSessionState(behavior);
this.classLoader = Thread.currentThread().getContextClassLoader();
this.components = components;
this.handler = handler;
Expand Down Expand Up @@ -219,16 +220,20 @@ public void onEof()
if (LOG.isDebugEnabled())
LOG.debug("onEof() {}", this);

if (sessionState.onEof())
closeConnection(sessionState.getCloseStatus(), Callback.NOOP);
WebSocketSessionState.EofResult result = sessionState.onEof();
if (result.shutdownOutput())
getConnection().getEndPoint().shutdownOutput();
if (result.closeEndpoint())
abort();
if (result.notifyWebSocketClose())
notifyWebSocketConnectionClose(sessionState.getCloseStatus(), Callback.NOOP);
}

private void closeConnection(CloseStatus closeStatus, Callback callback)
private void notifyWebSocketConnectionClose(CloseStatus closeStatus, Callback callback)
{
if (LOG.isDebugEnabled())
LOG.debug("closeConnection() {} {}", closeStatus, this);

abort();
extensionStack.close();

// Forward Errors to Local WebSocket EndPoint
Expand Down Expand Up @@ -335,9 +340,13 @@ private void processError(CloseStatus closeStatus, Callback callback)
}
else
{
if (sessionState.onClosed(closeStatus))
WebSocketSessionState.Result result = sessionState.onClosed(closeStatus);
if (result.closeEndpoint())
abort();

if (result.notifyWebSocketClose())
{
closeConnection(closeStatus, callback);
notifyWebSocketConnectionClose(closeStatus, callback);
}
else
{
Expand Down Expand Up @@ -506,13 +515,39 @@ public void sendFrame(OutgoingEntry entry)
if (LOG.isDebugEnabled())
LOG.debug("sendFrame({}, {}, {})", frame, callback, batch);

boolean closeConnection = sessionState.onOutgoingFrame(frame);
if (closeConnection)
WebSocketSessionState.Result result = sessionState.onOutgoingFrame(frame);
callback = Callback.from(callback, failure ->
{
if (failure != null)
{
CloseStatus closeStatus = new CloseStatus(CloseStatus.NO_CLOSE, failure);
WebSocketSessionState.Result closeResult = sessionState.onClosed(closeStatus);
if (closeResult.closeEndpoint())
abort();
if (closeResult.notifyWebSocketClose())
notifyWebSocketConnectionClose(closeStatus, NOOP);
return;
}

if (frame.getOpCode() == OpCode.CLOSE)
{
WebSocketSessionState.CloseResult closeResult = sessionState.onCloseFrameSent();
if (closeResult.shutdownOutput())
connection.getEndPoint().shutdownOutput();
if (closeResult.closeEndpoint())
abort();
}

if (result.closeEndpoint())
abort();
});

if (result.notifyWebSocketClose())
{
Callback c = callback;
Callback closeConnectionCallback = Callback.from(
() -> closeConnection(sessionState.getCloseStatus(), c),
t -> closeConnection(sessionState.getCloseStatus(), Callback.from(c, t)));
() -> notifyWebSocketConnectionClose(sessionState.getCloseStatus(), c),
t -> notifyWebSocketConnectionClose(sessionState.getCloseStatus(), Callback.from(c, t)));

flusher.sendFrame(new OutgoingEntry.Builder(entry)
.callback(closeConnectionCallback)
Expand All @@ -533,8 +568,17 @@ public void sendFrame(OutgoingEntry entry)
if (frame.getOpCode() == OpCode.CLOSE)
{
CloseStatus closeStatus = CloseStatus.getCloseStatus(frame);
if (closeStatus.isAbnormal() && sessionState.onClosed(closeStatus))
closeConnection(closeStatus, Callback.from(callback, t));
if (closeStatus.isAbnormal())
{
WebSocketSessionState.Result result = sessionState.onClosed(closeStatus);
if (result.closeEndpoint())
abort();

if (result.notifyWebSocketClose())
notifyWebSocketConnectionClose(closeStatus, Callback.from(callback, t));
else
callback.failed(t);
}
else
callback.failed(t);
}
Expand Down Expand Up @@ -656,7 +700,9 @@ public void onFrame(Frame frame, Callback callback)
if (LOG.isDebugEnabled())
LOG.debug("receiveFrame({}, {}) - connectionState={}, handler={}", frame, callback, sessionState, handler);

boolean closeConnection = sessionState.onIncomingFrame(frame);
WebSocketSessionState.Result result = sessionState.onIncomingFrame(frame);
if (result.closeEndpoint())
abort();

// Handle inbound frame
if (frame.getOpCode() != OpCode.CLOSE)
Expand All @@ -665,14 +711,15 @@ public void onFrame(Frame frame, Callback callback)
return;
}

// Handle inbound CLOSE
// Cancel demand to read to EOF, as we cannot receive any more frames after the CLOSE Frame.
connection.cancelDemand();
if (closeConnection)
if (result.notifyWebSocketClose())
{
closeCallback = Callback.from(() -> closeConnection(sessionState.getCloseStatus(), callback), t ->
closeCallback = Callback.from(() -> notifyWebSocketConnectionClose(sessionState.getCloseStatus(), callback), t ->
{
sessionState.onError(t);
closeConnection(sessionState.getCloseStatus(), callback);
if (sessionState.onError(t))
abort();
notifyWebSocketConnectionClose(sessionState.getCloseStatus(), callback);
});
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,6 @@ protected void onSuccess()

for (FlusherEntry entry : _completedEntries)
{
if (entry.getFrame().getOpCode() == OpCode.CLOSE && _behavior == Behavior.SERVER)
_endPoint.shutdownOutput();
notifyCallbackSuccess(entry.getCallback());
}
_completedEntries.clear();
Expand Down
Loading