diff --git a/nvflare/fuel/hci/conn.py b/nvflare/fuel/hci/conn.py index 3d35a77f62..5542641f22 100644 --- a/nvflare/fuel/hci/conn.py +++ b/nvflare/fuel/hci/conn.py @@ -27,16 +27,20 @@ LINE_END = "\x03" # Indicates the end of the line (end of text) ALL_END = "\x04" # Marks the end of a complete transmission (End of Transmission) - - MAX_MSG_SIZE = 1024 +MAX_DATA_SIZE = 512 * 1024 * 1024 +MAX_IDLE_TIME = 10 def receive_til_end(sock, end=ALL_END): total_data = [] - + data_size = 0 + sock.settimeout(MAX_IDLE_TIME) while True: data = str(sock.recv(1024), "utf-8") + data_size += len(data) + if data_size > MAX_DATA_SIZE: + raise BufferError(f"Data size exceeds limit ({MAX_DATA_SIZE} bytes)") if end in data: total_data.append(data[: data.find(end)]) break diff --git a/nvflare/fuel/hci/server/hci.py b/nvflare/fuel/hci/server/hci.py index 2d58b6cdfe..96e0d0cc73 100644 --- a/nvflare/fuel/hci/server/hci.py +++ b/nvflare/fuel/hci/server/hci.py @@ -16,7 +16,6 @@ import socketserver import ssl import threading -import traceback from nvflare.fuel.hci.conn import Connection, receive_til_end from nvflare.fuel.hci.proto import validate_proto @@ -24,6 +23,8 @@ from .reg import ServerCommandRegister +MAX_ADMIN_CONNECTIONS = 16 + class _MsgHandler(socketserver.BaseRequestHandler): """Message handler. @@ -32,8 +33,23 @@ class _MsgHandler(socketserver.BaseRequestHandler): ServerCommandRegister. """ + connections = 0 + lock = threading.Lock() + + def __init__(self, request, client_address, server): + # handle() is called in the constructor so logger must be initialized first + self.logger = logging.getLogger(self.__class__.__name__) + super().__init__(request, client_address, server) + def handle(self): try: + with _MsgHandler.lock: + _MsgHandler.connections += 1 + + self.logger.debug(f"Concurrent admin connections: {_MsgHandler.connections}") + if _MsgHandler.connections > MAX_ADMIN_CONNECTIONS: + raise ConnectionRefusedError(f"Admin connection limit ({MAX_ADMIN_CONNECTIONS}) reached") + conn = Connection(self.request, self.server) if self.server.use_ssl: @@ -68,8 +84,13 @@ def handle(self): if not conn.ended: conn.close() - except BaseException: - traceback.print_exc() + except BaseException as exc: + self.logger.error(f"Admin connection terminated due to exception: {str(exc)}") + if self.logger.getEffectiveLevel() <= logging.DEBUG: + self.logger.exception("Admin connection error") + finally: + with _MsgHandler.lock: + _MsgHandler.connections -= 1 def initialize_hci(): @@ -121,7 +142,7 @@ def __init__( ctx.load_verify_locations(ca_cert) ctx.load_cert_chain(certfile=server_cert, keyfile=server_key) - # replace the socket with an ssl version of itself + # replace the socket with an SSL version of itself self.socket = ctx.wrap_socket(self.socket, server_side=True) self.use_ssl = True