diff --git a/src/WebsocketManager/ConnectionManager.cs b/src/WebsocketManager/ConnectionManager.cs index d268e13..07256d7 100644 --- a/src/WebsocketManager/ConnectionManager.cs +++ b/src/WebsocketManager/ConnectionManager.cs @@ -25,19 +25,30 @@ public string GetId(WebSocket socket) { return _sockets.FirstOrDefault(p => p.Value == socket).Key; } + public void AddSocket(WebSocket socket) { _sockets.TryAdd(CreateConnectionId(), socket); } - public async Task RemoveSocket(string id) + public async Task CloseAndRemoveSocket(string id) { - WebSocket socket; - _sockets.TryRemove(id, out socket); + if (_sockets.TryRemove(id, out var socket)) + { + await socket.CloseOutputAsync(closeStatus: WebSocketCloseStatus.NormalClosure, + statusDescription: "Closed by the ConnectionManager", + cancellationToken: CancellationToken.None); + } + } + + public Task RemoveSocket(string id) + { + if (_sockets.TryRemove(id, out var socket)) + { + socket.Dispose(); + } - await socket.CloseAsync(closeStatus: WebSocketCloseStatus.NormalClosure, - statusDescription: "Closed by the ConnectionManager", - cancellationToken: CancellationToken.None); + return Task.CompletedTask; } private string CreateConnectionId() diff --git a/src/WebsocketManager/Handler.cs b/src/WebsocketManager/Handler.cs index 971e31d..20b0943 100644 --- a/src/WebsocketManager/Handler.cs +++ b/src/WebsocketManager/Handler.cs @@ -20,6 +20,11 @@ public virtual async Task OnConnected(WebSocket socket) WebSocketConnectionManager.AddSocket(socket); } + public virtual async Task OnCloseConnection(WebSocket socket) + { + await WebSocketConnectionManager.CloseAndRemoveSocket(WebSocketConnectionManager.GetId(socket)); + } + public virtual async Task OnDisconnected(WebSocket socket) { await WebSocketConnectionManager.RemoveSocket(WebSocketConnectionManager.GetId(socket)); diff --git a/src/WebsocketManager/Middleware.cs b/src/WebsocketManager/Middleware.cs index ba35f17..a607cbc 100644 --- a/src/WebsocketManager/Middleware.cs +++ b/src/WebsocketManager/Middleware.cs @@ -20,30 +20,40 @@ public WebSocketManagerMiddleware(RequestDelegate next, public async Task Invoke(HttpContext context) { - if(!context.WebSockets.IsWebSocketRequest) - return; - - var socket = await context.WebSockets.AcceptWebSocketAsync(); - await _webSocketHandler.OnConnected(socket); - - await Receive(socket, async(result, buffer) => + if (context.WebSockets.IsWebSocketRequest) { - if(result.MessageType == WebSocketMessageType.Text) - { - await _webSocketHandler.ReceiveAsync(socket, result, buffer); - return; - } + var socket = await context.WebSockets.AcceptWebSocketAsync(); + await _webSocketHandler.OnConnected(socket); - else if(result.MessageType == WebSocketMessageType.Close) + try { - await _webSocketHandler.OnDisconnected(socket); - return; + await Receive(socket, async(result, buffer) => + { + if (result.MessageType == WebSocketMessageType.Text) + { + await _webSocketHandler.ReceiveAsync(socket, result, buffer); + } + else if(result.MessageType == WebSocketMessageType.Close) + { + await _webSocketHandler.OnCloseConnection(socket); + } + }); } + catch (WebSocketException e) + { + if (e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + await _webSocketHandler.OnDisconnected(socket); + return; + } - }); - - //TODO - investigate the Kestrel exception thrown when this is the last middleware - //await _next.Invoke(context); + throw; + } + } + else + { + await _next.Invoke(context); + } } private async Task Receive(WebSocket socket, Action handleMessage)