@@ -238,12 +238,20 @@ def __init__(
238238 socket : socket .socket ,
239239 handler : Callable [[socket .socket , Any ], None ],
240240 logger : LoggerLike | None = None ,
241+ * ,
242+ connections : set [ServerConnection ] | None = None ,
241243 ) -> None :
242244 self .socket = socket
243245 self .handler = handler
244246 if logger is None :
245247 logger = logging .getLogger ("websockets.server" )
246248 self .logger = logger
249+
250+ # _connections tracks active connections
251+ if connections is None :
252+ connections = set ()
253+ self ._connections = connections
254+
247255 if sys .platform != "win32" :
248256 self .shutdown_watcher , self .shutdown_notifier = os .pipe ()
249257
@@ -285,15 +293,36 @@ def serve_forever(self) -> None:
285293 thread = threading .Thread (target = self .handler , args = (sock , addr ))
286294 thread .start ()
287295
288- def shutdown (self ) -> None :
296+ def shutdown (
297+ self , * , code : CloseCode = CloseCode .NORMAL_CLOSURE , reason : str = ""
298+ ) -> None :
289299 """
290300 See :meth:`socketserver.BaseServer.shutdown`.
291301
302+ Shuts down the server and closes existing connections. Optional arguments
303+ ``code`` and ``reason`` can be used to provide additional information to
304+ the clients, e.g.,::
305+
306+ server.shutdown(reason="scheduled_maintenance")
307+
308+ Args:
309+ code: Closing code, defaults to ``CloseCode.NORMAL_CLOSURE``.
310+ reason: Closing reason, default to empty string.
311+
292312 """
293313 self .socket .close ()
294314 if sys .platform != "win32" :
295315 os .write (self .shutdown_notifier , b"x" )
296316
317+ # Close all connections
318+ conns = list (self ._connections )
319+ for conn in conns :
320+ try :
321+ conn .close (code = code , reason = reason )
322+ except Exception as exc :
323+ debug_msg = f"Could not close { conn .id } : { exc } "
324+ self .logger .debug (debug_msg , exc_info = exc )
325+
297326 def fileno (self ) -> int :
298327 """
299328 See :meth:`socketserver.BaseServer.fileno`.
@@ -516,6 +545,24 @@ def handler(websocket):
516545 do_handshake_on_connect = False ,
517546 )
518547
548+ # Stores active ServerConnection instances, used by the server to handle graceful
549+ # shutdown in Server.shutdown()
550+ connections : set [ServerConnection ] = set ()
551+
552+ def on_connection_created (connection : ServerConnection ) -> None :
553+ # Invoked from conn_handler() to add a new ServerConnection instance to
554+ # Server._connections
555+ connections .add (connection )
556+
557+ def on_connection_closed (connection : ServerConnection ) -> None :
558+ # Invoked from conn_handler() to remove a closed ServerConnection instance from
559+ # Server._connections. Keeping only active references in the set is important
560+ # for avoiding memory leaks.
561+ try :
562+ connections .remove (connection )
563+ except KeyError : # pragma: no cover
564+ pass
565+
519566 # Define request handler
520567
521568 def conn_handler (sock : socket .socket , addr : Any ) -> None :
@@ -581,6 +628,7 @@ def protocol_select_subprotocol(
581628 close_timeout = close_timeout ,
582629 max_queue = max_queue ,
583630 )
631+ on_connection_created (connection )
584632 except Exception :
585633 sock .close ()
586634 return
@@ -595,11 +643,13 @@ def protocol_select_subprotocol(
595643 )
596644 except TimeoutError :
597645 connection .close_socket ()
646+ on_connection_closed (connection )
598647 connection .recv_events_thread .join ()
599648 return
600649 except Exception :
601650 connection .logger .error ("opening handshake failed" , exc_info = True )
602651 connection .close_socket ()
652+ on_connection_closed (connection )
603653 connection .recv_events_thread .join ()
604654 return
605655
@@ -610,16 +660,18 @@ def protocol_select_subprotocol(
610660 except Exception :
611661 connection .logger .error ("connection handler failed" , exc_info = True )
612662 connection .close (CloseCode .INTERNAL_ERROR )
663+ on_connection_closed (connection )
613664 else :
614665 connection .close ()
666+ on_connection_closed (connection )
615667
616668 except Exception : # pragma: no cover
617669 # Don't leak sockets on unexpected errors.
618670 sock .close ()
619671
620672 # Initialize server
621673
622- return Server (sock , conn_handler , logger )
674+ return Server (sock , conn_handler , logger , connections = connections )
623675
624676
625677def unix_serve (
0 commit comments