Skip to content

Commit

Permalink
Merged security fix to admin conn (#688)
Browse files Browse the repository at this point in the history
* Update workspace related documentation (#684)

* Update workspace related documentation

* Add more details to server/client workspace and add reference

* Update documentation format (#685)

* Cherry-picked security fix for admin conn to 2.1 branch

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
nvidianz and YuanTingHsieh authored Jun 18, 2022
1 parent 81dd280 commit 2b6e1b1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
10 changes: 7 additions & 3 deletions nvflare/fuel/hci/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@

LINE_END = "\x03" # Indicates the end of the line (end of text)
ALL_END = "\x04" # Marks the end of a complete transmission (End of Transmission)


MAX_MSG_SIZE = 1024
MAX_DATA_SIZE = 512 * 1024 * 1024
MAX_IDLE_TIME = 10


def receive_til_end(sock, end=ALL_END):
total_data = []

data_size = 0
sock.settimeout(MAX_IDLE_TIME)
while True:
data = str(sock.recv(1024), "utf-8")
data_size += len(data)
if data_size > MAX_DATA_SIZE:
raise BufferError(f"Data size exceeds limit ({MAX_DATA_SIZE} bytes)")
if end in data:
total_data.append(data[: data.find(end)])
break
Expand Down
29 changes: 25 additions & 4 deletions nvflare/fuel/hci/server/hci.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
import socketserver
import ssl
import threading
import traceback

from nvflare.fuel.hci.conn import Connection, receive_til_end
from nvflare.fuel.hci.proto import validate_proto
from nvflare.fuel.hci.security import get_certificate_common_name

from .reg import ServerCommandRegister

MAX_ADMIN_CONNECTIONS = 16


class _MsgHandler(socketserver.BaseRequestHandler):
"""Message handler.
Expand All @@ -32,8 +33,23 @@ class _MsgHandler(socketserver.BaseRequestHandler):
ServerCommandRegister.
"""

connections = 0
lock = threading.Lock()

def __init__(self, request, client_address, server):
# handle() is called in the constructor so logger must be initialized first
self.logger = logging.getLogger(self.__class__.__name__)
super().__init__(request, client_address, server)

def handle(self):
try:
with _MsgHandler.lock:
_MsgHandler.connections += 1

self.logger.debug(f"Concurrent admin connections: {_MsgHandler.connections}")
if _MsgHandler.connections > MAX_ADMIN_CONNECTIONS:
raise ConnectionRefusedError(f"Admin connection limit ({MAX_ADMIN_CONNECTIONS}) reached")

conn = Connection(self.request, self.server)

if self.server.use_ssl:
Expand Down Expand Up @@ -68,8 +84,13 @@ def handle(self):

if not conn.ended:
conn.close()
except BaseException:
traceback.print_exc()
except BaseException as exc:
self.logger.error(f"Admin connection terminated due to exception: {str(exc)}")
if self.logger.getEffectiveLevel() <= logging.DEBUG:
self.logger.exception("Admin connection error")
finally:
with _MsgHandler.lock:
_MsgHandler.connections -= 1


def initialize_hci():
Expand Down Expand Up @@ -121,7 +142,7 @@ def __init__(
ctx.load_verify_locations(ca_cert)
ctx.load_cert_chain(certfile=server_cert, keyfile=server_key)

# replace the socket with an ssl version of itself
# replace the socket with an SSL version of itself
self.socket = ctx.wrap_socket(self.socket, server_side=True)
self.use_ssl = True

Expand Down

0 comments on commit 2b6e1b1

Please sign in to comment.