Skip to content

Commit

Permalink
Fix #258 by adding 3 measures against DoS attacks (#295)
Browse files Browse the repository at this point in the history
1. Message size is limited to 0.5 GB
2. Connection is timed out after idling for 10 seconds
3. Number of admin connections is limited to 128
  • Loading branch information
nvidianz authored Mar 12, 2022
1 parent 52fa8fc commit 782362b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
10 changes: 7 additions & 3 deletions nvflare/fuel/hci/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@

LINE_END = "\x03" # Indicates the end of the line (end of text)
ALL_END = "\x04" # Marks the end of a complete transmission (End of Transmission)


MAX_MSG_SIZE = 1024
MAX_DATA_SIZE = 512 * 1024 * 1024
MAX_IDLE_TIME = 10


def receive_til_end(sock, end=ALL_END):
total_data = []

data_size = 0
sock.settimeout(MAX_IDLE_TIME)
while True:
data = str(sock.recv(1024), "utf-8")
data_size += len(data)
if data_size > MAX_DATA_SIZE:
raise BufferError(f"Data size exceeds limit ({MAX_DATA_SIZE} bytes)")
if end in data:
total_data.append(data[: data.find(end)])
break
Expand Down
30 changes: 25 additions & 5 deletions nvflare/fuel/hci/server/hci.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import socketserver
import ssl
import threading
import traceback

from nvflare.fuel.hci.conn import Connection, receive_til_end
from nvflare.fuel.hci.proto import validate_proto
from nvflare.fuel.hci.security import get_certificate_common_name

from .reg import ServerCommandRegister

MAX_ADMIN_CONNECTIONS = 128


class _MsgHandler(socketserver.BaseRequestHandler):
"""Message handler.
Expand All @@ -32,8 +32,23 @@ class _MsgHandler(socketserver.BaseRequestHandler):
ServerCommandRegister.
"""

connections = 0
lock = threading.Lock()

def __init__(self, request, client_address, server):
# handle() is called in the constructor so logger must be initialized first
self.logger = logging.getLogger(self.__class__.__name__)
super().__init__(request, client_address, server)

def handle(self):
try:
with _MsgHandler.lock:
_MsgHandler.connections += 1

self.logger.debug(f"Concurrent admin connections: {_MsgHandler.connections}")
if _MsgHandler.connections > MAX_ADMIN_CONNECTIONS:
raise ConnectionRefusedError(f"Admin connection limit ({MAX_ADMIN_CONNECTIONS}) reached")

conn = Connection(self.request, self.server)

if self.server.use_ssl:
Expand Down Expand Up @@ -68,8 +83,13 @@ def handle(self):

if not conn.ended:
conn.close()
except BaseException:
traceback.print_exc()
except BaseException as exc:
self.logger.error(f"Admin connection terminated due to exception: {str(exc)}")
if self.logger.getEffectiveLevel() <= logging.DEBUG:
self.logger.exception("Admin connection error")
finally:
with _MsgHandler.lock:
_MsgHandler.connections -= 1


def initialize_hci():
Expand Down Expand Up @@ -121,7 +141,7 @@ def __init__(
ctx.load_verify_locations(ca_cert)
ctx.load_cert_chain(certfile=server_cert, keyfile=server_key)

# replace the socket with an ssl version of itself
# replace the socket with an SSL version of itself
self.socket = ctx.wrap_socket(self.socket, server_side=True)
self.use_ssl = True

Expand Down

0 comments on commit 782362b

Please sign in to comment.