diff --git a/nvflare/fuel/f3/comm_config.py b/nvflare/fuel/f3/comm_config.py index c30670176c..77189c07ad 100644 --- a/nvflare/fuel/f3/comm_config.py +++ b/nvflare/fuel/f3/comm_config.py @@ -33,24 +33,42 @@ class VarName: SUBNET_HEARTBEAT_INTERVAL = "subnet_heartbeat_interval" SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold" COMM_DRIVER_PATH = "comm_driver_path" + USE_AIO_GRPC_VAR_NAME = "use_aio_grpc" class CommConfigurator: + _config_loaded = False + _configuration = None + def __init__(self): + # only load once! self.logger = logging.getLogger(self.__class__.__name__) - config = None - for file_name in _comm_config_files: - try: - config = ConfigService.load_json(file_name) - if config: - break - except FileNotFoundError: - self.logger.debug(f"config file {file_name} not found from config path") - config = None - except Exception as ex: - self.logger.error(f"failed to load config file {file_name}: {secure_format_exception(ex)}") - config = None - self.config = config + if not CommConfigurator._config_loaded: + config = None + for file_name in _comm_config_files: + try: + config = ConfigService.load_json(file_name) + if config: + break + except FileNotFoundError: + self.logger.debug(f"config file {file_name} not found from config path") + config = None + except Exception as ex: + self.logger.error(f"failed to load config file {file_name}: {secure_format_exception(ex)}") + config = None + + CommConfigurator._configuration = config + CommConfigurator._config_loaded = True + self.config = CommConfigurator._configuration + + @staticmethod + def reset(): + """Reset the configurator to allow reloading config files. + + Returns: + + """ + CommConfigurator._config_loaded = False def get_config(self): return self.config @@ -78,3 +96,18 @@ def get_subnet_trouble_threshold(self, default): def get_comm_driver_path(self, default): return ConfigService.get_str_var(VarName.COMM_DRIVER_PATH, self.config, default=default) + + def use_aio_grpc(self, default): + return ConfigService.get_bool_var(VarName.USE_AIO_GRPC_VAR_NAME, self.config, default) + + def get_int_var(self, name: str, default=None): + return ConfigService.get_int_var(name, self.config, default=default) + + def get_float_var(self, name: str, default=None): + return ConfigService.get_float_var(name, self.config, default=default) + + def get_bool_var(self, name: str, default=None): + return ConfigService.get_bool_var(name, self.config, default=default) + + def get_str_var(self, name: str, default=None): + return ConfigService.get_str_var(name, self.config, default=default) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 12671e74eb..74c894be13 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +import random import threading import time from typing import Any, Dict, List @@ -34,6 +35,7 @@ from .base_driver import BaseDriver from .driver_params import DriverCap, DriverParams from .grpc.streamer_pb2 import Frame +from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required GRPC_DEFAULT_OPTIONS = [ @@ -65,6 +67,18 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di self.channel = channel # for client side self.lock = threading.Lock() + conf = CommConfigurator() + if conf.get_bool_var("simulate_unstable_network", default=False): + self.disconn = threading.Thread(target=self._disconnect, daemon=True) + self.disconn.start() + + def _disconnect(self): + t = random.randint(10, 60) + self.logger.info(f"will close connection after {t} secs") + time.sleep(t) + self.logger.info(f"close connection now after {t} secs") + self.close() + def get_conn_properties(self) -> dict: return self.conn_props @@ -101,18 +115,18 @@ async def read_loop(self, msg_iter): except grpc.aio.AioRpcError as error: if not self.closing: if error.code() == grpc.StatusCode.CANCELLED: - self.logger.debug(f"Connection {self} is closed by peer") + self.logger.info(f"Connection {self} is closed by peer") else: - self.logger.debug(f"Connection {self} Error: {error.details()}") + self.logger.info(f"Connection {self} Error: {error.details()}") self.logger.debug(secure_format_traceback()) else: - self.logger.debug(f"Connection {self} is closed locally") + self.logger.info(f"Connection {self} is closed locally") except Exception as ex: if not self.closing: - self.logger.debug(f"{self}: exception {type(ex)} in read_loop: {secure_format_exception(ex)}") + self.logger.info(f"{self}: exception {type(ex)} in read_loop: {secure_format_exception(ex)}") self.logger.debug(secure_format_traceback()) - self.logger.debug(f"{self}: in {ct.name}: done read_loop") + self.logger.info(f"{self}: in {ct.name}: done read_loop") async def generate_output(self): ct = threading.current_thread() @@ -123,11 +137,10 @@ async def generate_output(self): yield item except Exception as ex: if self.closing: - self.logger.debug(f"{self}: connection closed by {type(ex)}: {secure_format_exception(ex)}") + self.logger.info(f"{self}: connection closed by {type(ex)}: {secure_format_exception(ex)}") else: - self.logger.debug(f"{self}: generate_output exception {type(ex)}: {secure_format_exception(ex)}") + self.logger.info(f"{self}: generate_output exception {type(ex)}: {secure_format_exception(ex)}") self.logger.debug(secure_format_traceback()) - self.logger.debug(f"{self}: done generate_output") @@ -137,20 +150,10 @@ def __init__(self, server, aio_ctx: AioContext): self.aio_ctx = aio_ctx self.logger = get_logger(self) - async def _write_loop(self, connection, grpc_context): - self.logger.debug("started _write_loop") - try: - while True: - f = await connection.oq.get() - await grpc_context.write(f) - except Exception as ex: - self.logger.debug(f"_write_loop except: {type(ex)}: {secure_format_exception(ex)}") - self.logger.debug("finished _write_loop") - async def Stream(self, request_iterator, context): connection = None - ct = threading.current_thread() try: + ct = threading.current_thread() self.logger.debug(f"SERVER started Stream CB in thread {ct.name}") conn_props = { DriverParams.PEER_ADDR.value: context.peer(), @@ -169,23 +172,22 @@ async def Stream(self, request_iterator, context): ) self.logger.debug(f"SERVER created connection in thread {ct.name}") self.server.driver.add_connection(connection) - try: - await asyncio.gather(self._write_loop(connection, context), connection.read_loop(request_iterator)) - except asyncio.CancelledError: - self.logger.debug("SERVER: RPC cancelled") - except Exception as ex: - self.logger.debug(f"await gather except: {type(ex)}: {secure_format_exception(ex)}") - self.logger.debug(f"SERVER: done await gather in thread {ct.name}") - + self.aio_ctx.run_coro(connection.read_loop(request_iterator)) + while True: + item = await connection.oq.get() + yield item + except asyncio.CancelledError: + self.logger.info("SERVER: RPC cancelled") except Exception as ex: - self.logger.debug(f"Connection closed due to error: {secure_format_exception(ex)}") + if connection: + self.logger.info(f"{connection}: connection exception: {secure_format_exception(ex)}") + self.logger.debug(secure_format_traceback()) finally: if connection: - with connection.lock: - connection.context = None - self.logger.debug(f"SERVER: closing connection {connection.name}") + connection.close() + self.logger.info(f"SERVER: closed connection {connection.name}") self.server.driver.close_connection(connection) - self.logger.debug(f"SERVER: cleanly finished Stream CB in thread {ct.name}") + self.logger.info("SERVER: finished Stream CB") class Server: @@ -207,10 +209,12 @@ def __init__(self, driver, connector, aio_ctx: AioContext, options, conn_ctx: _C secure = ssl_required(params) if secure: - credentials = AioGrpcDriver.get_grpc_server_credentials(params) + credentials = get_grpc_server_credentials(params) self.grpc_server.add_secure_port(addr, server_credentials=credentials) + self.logger.info(f"added secure port at {addr}") else: self.grpc_server.add_insecure_port(addr) + self.logger.info(f"added insecure port at {addr}") except Exception as ex: conn_ctx.error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}" self.logger.debug(conn_ctx.error) @@ -251,7 +255,10 @@ def __init__(self): @staticmethod def supported_transports() -> List[str]: - return ["grpc", "grpcs"] + if use_aio_grpc(): + return ["grpc", "grpcs"] + else: + return ["agrpc", "agrpcs"] @staticmethod def capabilities() -> Dict[str, Any]: @@ -280,9 +287,9 @@ def listen(self, connector: ConnectorInfo): time.sleep(0.1) if conn_ctx.error: raise CommError(code=CommError.ERROR, message=conn_ctx.error) - self.logger.debug("SERVER: waiting for server to finish") + self.logger.info(f"SERVER: listening on {connector}") conn_ctx.waiter.wait() - self.logger.debug("SERVER: server is done") + self.logger.info(f"SERVER: server is done listening on {connector}") async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, conn_ctx: _ConnCtx): self.logger.debug("Started _start_connect coro") @@ -295,10 +302,12 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co secure = ssl_required(params) if secure: grpc_channel = grpc.aio.secure_channel( - address, options=self.options, credentials=self.get_grpc_client_credentials(params) + address, options=self.options, credentials=get_grpc_client_credentials(params) ) + self.logger.info(f"created secure channel at {address}") else: grpc_channel = grpc.aio.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") async with grpc_channel as channel: self.logger.debug(f"CLIENT: connected to {address}") @@ -358,6 +367,7 @@ def connect(self, connector: ConnectorInfo): self.add_connection(conn_ctx.conn) conn_ctx.waiter.wait() self.close_connection(conn_ctx.conn) + self.logger.info(f"CLIENT: connection {conn_ctx.conn} closed") def shutdown(self): if self.closing: @@ -374,38 +384,9 @@ def shutdown(self): def get_urls(scheme: str, resources: dict) -> (str, str): secure = resources.get(DriverParams.SECURE) if secure: - scheme = "grpcs" + if use_aio_grpc(): + scheme = "grpcs" + else: + scheme = "agrpcs" return get_tcp_urls(scheme, resources) - - @staticmethod - def get_grpc_client_credentials(params: dict): - - root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_CERT)) - private_key = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_KEY)) - - return grpc.ssl_channel_credentials( - certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert - ) - - @staticmethod - def get_grpc_server_credentials(params: dict): - - root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_CERT)) - private_key = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_KEY)) - - return grpc.ssl_server_credentials( - [(private_key, cert_chain)], - root_certificates=root_cert, - require_client_auth=True, - ) - - @staticmethod - def read_file(file_name: str): - if not file_name: - return None - - with open(file_name, "rb") as f: - return f.read() diff --git a/nvflare/fuel/f3/drivers/grpc/qq.py b/nvflare/fuel/f3/drivers/grpc/qq.py new file mode 100644 index 0000000000..ca0eeb25f2 --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc/qq.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import queue + + +class QueueClosed(Exception): + pass + + +class QQ: + def __init__(self): + self.q = queue.Queue() + self.closed = False + self.logger = logging.getLogger(self.__class__.__name__) + + def close(self): + self.closed = True + + def append(self, i): + if self.closed: + raise QueueClosed("queue stopped") + self.q.put_nowait(i) + + def __iter__(self): + return self + + def __next__(self): + if self.closed: + raise StopIteration() + while True: + try: + return self.q.get(block=True, timeout=0.1) + except queue.Empty: + if self.closed: + self.logger.debug("Queue closed - stop iteration") + raise StopIteration() + except Exception as e: + self.logger.error(f"queue exception {type(e)}") + raise e diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py new file mode 100644 index 0000000000..d95bb8138a --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import grpc + +from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +def use_aio_grpc(): + configurator = CommConfigurator() + return configurator.use_aio_grpc(default=True) + + +def get_grpc_client_credentials(params: dict): + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT)) + private_key = _read_file(params.get(DriverParams.CLIENT_KEY)) + return grpc.ssl_channel_credentials( + certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert + ) + + +def get_grpc_server_credentials(params: dict): + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.SERVER_CERT)) + private_key = _read_file(params.get(DriverParams.SERVER_KEY)) + + return grpc.ssl_server_credentials( + [(private_key, cert_chain)], + root_certificates=root_cert, + require_client_auth=True, + ) + + +def _read_file(file_name: str): + if not file_name: + return None + + with open(file_name, "rb") as f: + return f.read() diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py new file mode 100644 index 0000000000..a29399c253 --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -0,0 +1,287 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from concurrent import futures +from typing import Any, Dict, List, Union + +import grpc + +from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_error import CommError +from nvflare.fuel.f3.connection import Connection +from nvflare.fuel.f3.drivers.driver import ConnectorInfo +from nvflare.fuel.f3.drivers.grpc.streamer_pb2_grpc import ( + StreamerServicer, + StreamerStub, + add_StreamerServicer_to_server, +) +from nvflare.fuel.utils.obj_utils import get_logger +from nvflare.security.logging import secure_format_exception + +from .base_driver import BaseDriver +from .driver_params import DriverCap, DriverParams +from .grpc.qq import QQ +from .grpc.streamer_pb2 import Frame +from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc +from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required + +GRPC_DEFAULT_OPTIONS = [ + ("grpc.max_send_message_length", MAX_FRAME_SIZE), + ("grpc.max_receive_message_length", MAX_FRAME_SIZE), +] + + +class StreamConnection(Connection): + + seq_num = 0 + + def __init__(self, oq: QQ, connector: ConnectorInfo, conn_props: dict, side: str, context=None, channel=None): + super().__init__(connector) + self.side = side + self.oq = oq + self.closing = False + self.conn_props = conn_props + self.context = context # for server side + self.channel = channel # for client side + self.lock = threading.Lock() + self.logger = get_logger(self) + + def get_conn_properties(self) -> dict: + return self.conn_props + + def close(self): + self.closing = True + with self.lock: + self.oq.close() + if self.context: + try: + self.context.abort(grpc.StatusCode.CANCELLED, "service closed") + except: + # ignore any exception when aborting + pass + self.context = None + if self.channel: + self.channel.close() + self.channel = None + + def send_frame(self, frame: Union[bytes, bytearray, memoryview]): + try: + StreamConnection.seq_num += 1 + seq = StreamConnection.seq_num + self.logger.debug(f"{self.side}: queued frame #{seq}") + self.oq.append(Frame(seq=seq, data=bytes(frame))) + except BaseException as ex: + raise CommError(CommError.ERROR, f"Error sending frame: {ex}") + + def read_loop(self, msg_iter): + ct = threading.current_thread() + self.logger.debug(f"{self.side}: started read_loop in thread {ct.name}") + try: + for f in msg_iter: + if self.closing: + break + + assert isinstance(f, Frame) + self.logger.debug(f"{self.side} in {ct.name}: incoming frame #{f.seq}") + if self.frame_receiver: + self.frame_receiver.process_frame(f.data) + else: + self.logger.error(f"{self.side}: Frame receiver not registered for connection: {self.name}") + except Exception as ex: + if not self.closing: + self.logger.debug(f"{self.side}: exception {type(ex)} in read_loop") + if self.oq: + self.logger.debug(f"{self.side}: closing queue") + self.oq.close() + self.logger.debug(f"{self.side} in {ct.name}: done read_loop") + + def generate_output(self): + ct = threading.current_thread() + self.logger.debug(f"{self.side}: generate_output in thread {ct.name}") + for i in self.oq: + assert isinstance(i, Frame) + self.logger.debug(f"{self.side}: outgoing frame #{i.seq}") + yield i + self.logger.debug(f"{self.side}: done generate_output in thread {ct.name}") + + +class Servicer(StreamerServicer): + def __init__(self, server): + self.server = server + self.logger = get_logger(self) + + def Stream(self, request_iterator, context): + connection = None + oq = QQ() + t = None + ct = threading.current_thread() + conn_props = { + DriverParams.PEER_ADDR.value: context.peer(), + DriverParams.LOCAL_ADDR.value: get_address(self.server.connector.params), + } + cn_names = context.auth_context().get("x509_common_name") + if cn_names: + conn_props[DriverParams.PEER_CN.value] = cn_names[0].decode("utf-8") + + try: + self.logger.debug(f"SERVER started Stream CB in thread {ct.name}") + connection = StreamConnection(oq, self.server.connector, conn_props, "SERVER", context=context) + self.logger.debug(f"SERVER created connection in thread {ct.name}") + self.server.driver.add_connection(connection) + self.logger.debug(f"SERVER created read_loop thread in thread {ct.name}") + t = threading.Thread(target=connection.read_loop, args=(request_iterator,)) + t.start() + yield from connection.generate_output() + except BaseException as ex: + self.logger.error(f"Connection closed due to error: {ex}") + finally: + if t is not None: + t.join() + if connection: + self.logger.debug(f"SERVER: closing connection {connection.name}") + self.server.driver.close_connection(connection) + self.logger.info("SERVER: finished Stream CB") + + +class Server: + def __init__( + self, + driver, + connector, + max_workers, + options, + ): + self.driver = driver + self.logger = get_logger(self) + self.connector = connector + self.grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), options=options) + servicer = Servicer(self) + add_StreamerServicer_to_server(servicer, self.grpc_server) + + params = connector.params + addr = get_address(params) + try: + self.logger.debug(f"SERVER: connector params: {params}") + secure = ssl_required(params) + if secure: + credentials = get_grpc_server_credentials(params) + self.grpc_server.add_secure_port(addr, server_credentials=credentials) + self.logger.info(f"added secure port at {addr}") + else: + self.grpc_server.add_insecure_port(addr) + self.logger.info(f"added insecure port at {addr}") + except Exception as ex: + error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}" + self.logger.debug(error) + + def start(self): + self.grpc_server.start() + self.grpc_server.wait_for_termination() + + def shutdown(self): + self.grpc_server.stop(grace=0.5) + + +class GrpcDriver(BaseDriver): + def __init__(self): + BaseDriver.__init__(self) + self.server = None + self.closing = False + self.max_workers = 100 + self.options = GRPC_DEFAULT_OPTIONS + self.logger = get_logger(self) + configurator = CommConfigurator() + config = configurator.get_config() + if config: + my_params = config.get("grpc") + if my_params: + self.max_workers = my_params.get("max_workers", 100) + self.options = my_params.get("options") + self.logger.debug(f"GRPC Config: max_workers={self.max_workers}, options={self.options}") + + @staticmethod + def supported_transports() -> List[str]: + if use_aio_grpc(): + return ["nagrpc", "nagrpcs"] + else: + return ["grpc", "grpcs"] + + @staticmethod + def capabilities() -> Dict[str, Any]: + return {DriverCap.HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: True} + + def listen(self, connector: ConnectorInfo): + self.logger.info(f"starting grpc connector: {connector}") + self.connector = connector + self.server = Server(self, connector, max_workers=self.max_workers, options=self.options) + self.server.start() + + def connect(self, connector: ConnectorInfo): + self.logger.debug("CLIENT: trying connect ...") + params = connector.params + address = get_address(params) + conn_props = {DriverParams.PEER_ADDR.value: address} + + secure = ssl_required(params) + if secure: + self.logger.debug("CLIENT: creating secure channel") + channel = grpc.secure_channel( + address, options=self.options, credentials=get_grpc_client_credentials(params) + ) + self.logger.info(f"created secure channel at {address}") + else: + self.logger.info("CLIENT: creating insecure channel") + channel = grpc.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") + + with channel: + stub = StreamerStub(channel) + self.logger.debug("CLIENT: got stub") + oq = QQ() + connection = StreamConnection(oq, connector, conn_props, "CLIENT", channel=channel) + self.add_connection(connection) + self.logger.debug("CLIENT: added connection") + try: + received = stub.Stream(connection.generate_output()) + connection.read_loop(received) + + except BaseException as ex: + self.logger.info(f"CLIENT: connection done: {type(ex)}") + + with connection.lock: + # when we get here the channel is already closed + # set connection.channel to None to prevent closing channel again in connection.close(). + connection.channel = None + connection.close() + self.close_connection(connection) + self.logger.info(f"CLIENT: finished connection {connection}") + + @staticmethod + def get_urls(scheme: str, resources: dict) -> (str, str): + secure = resources.get(DriverParams.SECURE) + if secure: + if use_aio_grpc(): + scheme = "nagrpcs" + else: + scheme = "grpcs" + return get_tcp_urls(scheme, resources) + + def shutdown(self): + if self.closing: + return + self.closing = True + self.close_all() + if self.server: + self.server.shutdown() diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 377dfd783d..0b5520554c 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -86,7 +86,8 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: def get_address(params: dict) -> str: host = params.get(DriverParams.HOST.value, "0.0.0.0") port = params.get(DriverParams.PORT.value, 0) - + if not host: + host = "0.0.0.0" return f"{host}:{port}" diff --git a/nvflare/fuel/f3/qat/admin.py b/nvflare/fuel/f3/qat/admin.py index b334eb0dad..c48c6c7fc7 100644 --- a/nvflare/fuel/f3/qat/admin.py +++ b/nvflare/fuel/f3/qat/admin.py @@ -15,7 +15,7 @@ import argparse from nvflare.fuel.common.excepts import ConfigError -from nvflare.fuel.f3.qat2.net_config import NetConfig +from nvflare.fuel.f3.qat.net_config import NetConfig from nvflare.fuel.hci.client.cli import AdminClient, CredentialType from nvflare.fuel.hci.client.static_service_finder import StaticServiceFinder from nvflare.fuel.utils.config_service import ConfigService diff --git a/nvflare/fuel/f3/qat/run_cell.py b/nvflare/fuel/f3/qat/run_cell.py index bc2bd0e999..28b52eb1c8 100644 --- a/nvflare/fuel/f3/qat/run_cell.py +++ b/nvflare/fuel/f3/qat/run_cell.py @@ -16,7 +16,7 @@ import logging from nvflare.fuel.f3.mpm import MainProcessMonitor as Mpm -from nvflare.fuel.f3.qat2.cell_runner import CellRunner +from nvflare.fuel.f3.qat.cell_runner import CellRunner from nvflare.fuel.utils.config_service import ConfigService diff --git a/nvflare/fuel/f3/qat/run_server.py b/nvflare/fuel/f3/qat/run_server.py index c481d4fdc1..d0309f1710 100644 --- a/nvflare/fuel/f3/qat/run_server.py +++ b/nvflare/fuel/f3/qat/run_server.py @@ -16,7 +16,7 @@ import logging from nvflare.fuel.f3.mpm import MainProcessMonitor -from nvflare.fuel.f3.qat2.server import Server +from nvflare.fuel.f3.qat.server import Server from nvflare.fuel.utils.config_service import ConfigService diff --git a/nvflare/fuel/f3/sfm/conn_manager.py b/nvflare/fuel/f3/sfm/conn_manager.py index 8f812e18a2..caf0aa3451 100644 --- a/nvflare/fuel/f3/sfm/conn_manager.py +++ b/nvflare/fuel/f3/sfm/conn_manager.py @@ -37,7 +37,7 @@ FRAME_THREAD_POOL_SIZE = 100 CONN_THREAD_POOL_SIZE = 16 INIT_WAIT = 1 -MAX_WAIT = 60 +MAX_WAIT = 10 SILENT_RECONNECT_TIME = 5 SELF_ADDR = "0.0.0.0:0" diff --git a/nvflare/fuel/utils/config_service.py b/nvflare/fuel/utils/config_service.py index f3585b9d82..6884e23f4b 100644 --- a/nvflare/fuel/utils/config_service.py +++ b/nvflare/fuel/utils/config_service.py @@ -200,6 +200,8 @@ def _get_var(cls, name: str, conf): if isinstance(conf, dict): return conf.get(name) + else: + return None @classmethod def _int_var(cls, name: str, conf=None, default=None): diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 0a0f3ac704..dcbda9f68d 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -499,7 +499,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul ) def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: - self.log_info(fl_ctx, "received client job_heartbeat aux request") + self.log_debug(fl_ctx, "received client job_heartbeat") return make_reply(ReturnCode.OK) def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: diff --git a/tests/unit_test/fuel/f3/communicator_test.py b/tests/unit_test/fuel/f3/communicator_test.py index b4a3bd1355..e8cdbbe9d7 100644 --- a/tests/unit_test/fuel/f3/communicator_test.py +++ b/tests/unit_test/fuel/f3/communicator_test.py @@ -86,6 +86,7 @@ def comm_b(self): ("grpc", "3000-4000"), ("http", "4000-5000"), ("atcp", "5000-6000"), + ("nagrpc", "6000-7000"), ], ) def test_sfm_message(self, comm_a, comm_b, scheme, port_range): diff --git a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py index 5137f52db5..90583c653c 100644 --- a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py +++ b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py @@ -18,12 +18,14 @@ from nvflare.fuel.f3 import communicator # Setup custom driver path before communicator module initialization +from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.utils.config_service import ConfigService class TestCustomDriver: @pytest.fixture def manager(self): + CommConfigurator.reset() rel_path = "../../../data/custom_drivers/config" config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), rel_path)) ConfigService.initialize({}, [config_path]) diff --git a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py index a653a47af6..438ccea80b 100644 --- a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py +++ b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py @@ -20,6 +20,7 @@ from nvflare.fuel.f3.drivers.aio_http_driver import AioHttpDriver from nvflare.fuel.f3.drivers.aio_tcp_driver import AioTcpDriver from nvflare.fuel.f3.drivers.driver_manager import DriverManager +from nvflare.fuel.f3.drivers.grpc_driver import GrpcDriver from nvflare.fuel.f3.drivers.tcp_driver import TcpDriver @@ -37,6 +38,8 @@ def manager(self): ("stcp", TcpDriver), ("grpc", AioGrpcDriver), ("grpcs", AioGrpcDriver), + ("nagrpc", GrpcDriver), + ("nagrpcs", GrpcDriver), ("http", AioHttpDriver), ("https", AioHttpDriver), ("ws", AioHttpDriver), diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index 9984abdcba..a09ff7a4aa 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -26,7 +26,6 @@ from nvflare.lighter.impl.cert import serialize_cert from nvflare.lighter.utils import sign_folders, verify_folder_signature - folders = ["folder1", "folder2"] files = ["file1", "file2"]