From 2994afc55927f38461ea52fac8c487f126cb18e5 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang <100308595+nvidianz@users.noreply.github.com> Date: Mon, 26 Jun 2023 15:18:53 -0400 Subject: [PATCH] Merged custom driver to branch 2.3 (#1827) --- nvflare/fuel/f3/comm_config.py | 24 ++++---- nvflare/fuel/f3/communicator.py | 47 ++++++++++++---- nvflare/fuel/f3/drivers/driver_manager.py | 55 ++++++++++++------- nvflare/fuel/utils/config_service.py | 3 +- .../data/custom_drivers/com/__init__.py | 13 +++++ .../custom_drivers/com/example/__init__.py | 13 +++++ .../custom_drivers/com/example/warp_driver.py | 46 ++++++++++++++++ .../custom_drivers/config/comm_config.json | 3 + .../fuel/f3/drivers/custom_driver_test.py | 36 ++++++++++++ .../fuel/f3/drivers/driver_manager_test.py | 2 +- 10 files changed, 200 insertions(+), 42 deletions(-) create mode 100644 tests/unit_test/data/custom_drivers/com/__init__.py create mode 100644 tests/unit_test/data/custom_drivers/com/example/__init__.py create mode 100644 tests/unit_test/data/custom_drivers/com/example/warp_driver.py create mode 100644 tests/unit_test/data/custom_drivers/config/comm_config.json create mode 100644 tests/unit_test/fuel/f3/drivers/custom_driver_test.py diff --git a/nvflare/fuel/f3/comm_config.py b/nvflare/fuel/f3/comm_config.py index ed4b75cc6d..c30670176c 100644 --- a/nvflare/fuel/f3/comm_config.py +++ b/nvflare/fuel/f3/comm_config.py @@ -25,13 +25,14 @@ class VarName: - MAX_MSG_SIZE = "max_message_size" - ALLOW_ADHOC_CONNECTIONS = "allow_adhoc_conns" - ADHOC_CONNECTION_SCHEME = "adhoc_conn_scheme" - INTERNAL_CONNECTION_SCHEME = "internal_conn_scheme" - BACKBONE_CONNECTION_GENERATION = "backbone_conn_gen" + MAX_MESSAGE_SIZE = "max_message_size" + ALLOW_ADHOC_CONNS = "allow_adhoc_conns" + ADHOC_CONN_SCHEME = "adhoc_conn_scheme" + INTERNAL_CONN_SCHEME = "internal_conn_scheme" + BACKBONE_CONN_GEN = "backbone_conn_gen" SUBNET_HEARTBEAT_INTERVAL = "subnet_heartbeat_interval" SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold" + COMM_DRIVER_PATH = "comm_driver_path" class CommConfigurator: @@ -55,22 +56,25 @@ def get_config(self): return self.config def get_max_message_size(self): - return ConfigService.get_int_var(VarName.MAX_MSG_SIZE, self.config, default=DEFAULT_MAX_MSG_SIZE) + return ConfigService.get_int_var(VarName.MAX_MESSAGE_SIZE, self.config, default=DEFAULT_MAX_MSG_SIZE) def allow_adhoc_connections(self, default): - return ConfigService.get_bool_var(VarName.ALLOW_ADHOC_CONNECTIONS, self.config, default=default) + return ConfigService.get_bool_var(VarName.ALLOW_ADHOC_CONNS, self.config, default=default) def get_adhoc_connection_scheme(self, default): - return ConfigService.get_str_var(VarName.ADHOC_CONNECTION_SCHEME, self.config, default=default) + return ConfigService.get_str_var(VarName.ADHOC_CONN_SCHEME, self.config, default=default) def get_internal_connection_scheme(self, default): - return ConfigService.get_str_var(VarName.INTERNAL_CONNECTION_SCHEME, self.config, default=default) + return ConfigService.get_str_var(VarName.INTERNAL_CONN_SCHEME, self.config, default=default) def get_backbone_connection_generation(self, default): - return ConfigService.get_int_var(VarName.BACKBONE_CONNECTION_GENERATION, self.config, default=default) + return ConfigService.get_int_var(VarName.BACKBONE_CONN_GEN, self.config, default=default) def get_subnet_heartbeat_interval(self, default): return ConfigService.get_int_var(VarName.SUBNET_HEARTBEAT_INTERVAL, self.config, default) def get_subnet_trouble_threshold(self, default): return ConfigService.get_int_var(VarName.SUBNET_TROUBLE_THRESHOLD, self.config, default) + + def get_comm_driver_path(self, default): + return ConfigService.get_str_var(VarName.COMM_DRIVER_PATH, self.config, default=default) diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index df71d6979a..04d7dc7217 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -18,6 +18,7 @@ from typing import Optional from nvflare.fuel.f3 import drivers +from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.drivers.driver import Driver from nvflare.fuel.f3.drivers.driver_manager import DriverManager @@ -30,24 +31,31 @@ log = logging.getLogger(__name__) _running_instances = weakref.WeakSet() +driver_mgr = DriverManager() +driver_loaded = False -def _exit_func(): - for c in _running_instances: - c.stop() - log.debug(f"Communicator {c.local_endpoint.name} was left running, stopped on exit") +def load_comm_drivers(): + global driver_loaded + # Load all the drivers in the drivers module + driver_mgr.search_folder(os.path.dirname(drivers.__file__), drivers.__package__) -atexit.register(_exit_func) + # Load custom drivers + driver_path = CommConfigurator().get_comm_driver_path(None) + if not driver_path: + return + + for path in driver_path.split(os.pathsep): + log.debug(f"Custom driver folder {path} is searched") + driver_mgr.search_folder(path, None) + + driver_loaded = True class Communicator: """FCI (Flare Communication Interface) main communication API""" - driver_mgr = DriverManager() - # Load all the drivers in the drivers module - driver_mgr.register_folder(os.path.dirname(drivers.__file__), drivers.__package__) - def __init__(self, local_endpoint: Endpoint): self.local_endpoint = local_endpoint self.monitors = [] @@ -161,7 +169,10 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str: CommError: If any errors """ - driver_class = self.driver_mgr.find_driver_class(url) + if not driver_loaded: + load_comm_drivers() + + driver_class = driver_mgr.find_driver_class(url) if not driver_class: raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}") @@ -182,7 +193,10 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str): CommError: If any errors like invalid host or port not available """ - driver_class = self.driver_mgr.find_driver_class(scheme) + if not driver_loaded: + load_comm_drivers() + + driver_class = driver_mgr.find_driver_class(scheme) if not driver_class: raise CommError(CommError.NOT_SUPPORTED, f"No driver found for scheme {scheme}") @@ -243,3 +257,14 @@ def remove_connector(self, handle: str): CommError: If any errors """ self.conn_manager.remove_connector(handle) + + +def _exit_func(): + while _running_instances: + c = next(iter(_running_instances)) + # This call will remove the entry from the set + c.stop() + log.debug(f"Communicator {c.local_endpoint.name} was left running, stopped on exit") + + +atexit.register(_exit_func) diff --git a/nvflare/fuel/f3/drivers/driver_manager.py b/nvflare/fuel/f3/drivers/driver_manager.py index e950049708..29862d0366 100644 --- a/nvflare/fuel/f3/drivers/driver_manager.py +++ b/nvflare/fuel/f3/drivers/driver_manager.py @@ -15,6 +15,7 @@ import inspect import logging import os +import sys from typing import Optional, Type from nvflare.fuel.f3.comm_error import CommError @@ -27,8 +28,8 @@ class DriverManager: """Transport driver manager""" def __init__(self): - # scheme-< self.drivers = {} + self.class_cache = set() def register(self, driver_class: Type[Driver]): """Register a driver with Driver Manager @@ -51,29 +52,45 @@ def register(self, driver_class: Type[Driver]): self.drivers[key] = driver_class log.debug(f"Driver {driver_class.__name__} is registered for {scheme}") - def register_folder(self, folder: str, package: str): - """Scan the folder and register all drivers + def search_folder(self, folder: str, package: Optional[str]): + """Search the folder recursively and register all drivers Args: folder: The folder to scan - package: The root package for all the drivers + package: The root package for all the drivers. If none, the folder is the + root of the packages """ - class_cache = set() - - for file_name in os.listdir(folder): - if file_name != "__init__.py" and file_name[-3:] == ".py": - module = package + "." + file_name[:-3] - imported = importlib.import_module(module) - for _, cls_obj in inspect.getmembers(imported, inspect.isclass): - if cls_obj.__name__ in class_cache: - continue - class_cache.add(cls_obj.__name__) - - spec = inspect.getfullargspec(cls_obj.__init__) - # classes who are abstract or take extra args in __init__ can't be auto-registered - if issubclass(cls_obj, Driver) and not inspect.isabstract(cls_obj) and len(spec.args) == 1: - self.register(cls_obj) + if package is None and folder not in sys.path: + sys.path.append(folder) + + for root, dirs, files in os.walk(folder): + for filename in files: + if filename.endswith(".py"): + module = filename[:-3] + sub_folder = root[len(folder) :] + if sub_folder: + sub_folder = sub_folder.strip("/").replace("/", ".") + + if sub_folder: + module = sub_folder + "." + module + + if package: + module = package + "." + module + + imported = importlib.import_module(module) + for _, cls_obj in inspect.getmembers(imported, inspect.isclass): + if cls_obj.__name__ in self.class_cache: + continue + self.class_cache.add(cls_obj.__name__) + + if issubclass(cls_obj, Driver) and not inspect.isabstract(cls_obj): + spec = inspect.getfullargspec(cls_obj.__init__) + if len(spec.args) == 1: + self.register(cls_obj) + else: + # Can't handle argument in constructor + log.warning(f"Invalid driver, __init__ with extra arguments: {module}") def find_driver_class(self, scheme_or_url: str) -> Optional[Type[Driver]]: """Find the driver class based on scheme or URL diff --git a/nvflare/fuel/utils/config_service.py b/nvflare/fuel/utils/config_service.py index 465ad4bf28..f3585b9d82 100644 --- a/nvflare/fuel/utils/config_service.py +++ b/nvflare/fuel/utils/config_service.py @@ -216,7 +216,8 @@ def _any_var(cls, func, name, conf, default): if name in cls._var_values: return cls._var_values.get(name) v = func(name, conf, default) - cls._var_values[name] = v + if v is not None: + cls._var_values[name] = v return v @classmethod diff --git a/tests/unit_test/data/custom_drivers/com/__init__.py b/tests/unit_test/data/custom_drivers/com/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/tests/unit_test/data/custom_drivers/com/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/data/custom_drivers/com/example/__init__.py b/tests/unit_test/data/custom_drivers/com/example/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/tests/unit_test/data/custom_drivers/com/example/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/data/custom_drivers/com/example/warp_driver.py b/tests/unit_test/data/custom_drivers/com/example/warp_driver.py new file mode 100644 index 0000000000..67fb2afd7c --- /dev/null +++ b/tests/unit_test/data/custom_drivers/com/example/warp_driver.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List + +from nvflare.fuel.f3.drivers.base_driver import BaseDriver +from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo +from nvflare.fuel.f3.drivers.driver_params import DriverCap + + +class WarpDriver(BaseDriver): + """A dummy driver to test custom driver loading""" + + def __init__(self): + super().__init__() + + @staticmethod + def supported_transports() -> List[str]: + return ["warp"] + + @staticmethod + def capabilities() -> Dict[str, Any]: + return {DriverCap.HEARTBEAT.value: False, DriverCap.SUPPORT_SSL.value: False} + + def listen(self, connector: ConnectorInfo): + self.connector = connector + + def connect(self, connector: ConnectorInfo): + self.connector = connector + + def shutdown(self): + self.close_all() + + @staticmethod + def get_urls(scheme: str, resources: dict) -> (str, str): + return "warp:enterprise" diff --git a/tests/unit_test/data/custom_drivers/config/comm_config.json b/tests/unit_test/data/custom_drivers/config/comm_config.json new file mode 100644 index 0000000000..86bf7ac441 --- /dev/null +++ b/tests/unit_test/data/custom_drivers/config/comm_config.json @@ -0,0 +1,3 @@ +{ + "comm_driver_path": "./tests/unit_test/data/custom_drivers" +} \ No newline at end of file diff --git a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py new file mode 100644 index 0000000000..5137f52db5 --- /dev/null +++ b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest + +from nvflare.fuel.f3 import communicator + +# Setup custom driver path before communicator module initialization +from nvflare.fuel.utils.config_service import ConfigService + + +class TestCustomDriver: + @pytest.fixture + def manager(self): + rel_path = "../../../data/custom_drivers/config" + config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), rel_path)) + ConfigService.initialize({}, [config_path]) + communicator.load_comm_drivers() + + return communicator.driver_mgr + + def test_custom_driver_loading(self, manager): + driver_class = manager.find_driver_class("warp") + assert driver_class.__name__ == "WarpDriver" diff --git a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py index f346775aaf..a653a47af6 100644 --- a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py +++ b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py @@ -27,7 +27,7 @@ class TestDriverManager: @pytest.fixture def manager(self): driver_manager = DriverManager() - driver_manager.register_folder(os.path.dirname(drivers.__file__), drivers.__package__) + driver_manager.search_folder(os.path.dirname(drivers.__file__), drivers.__package__) return driver_manager @pytest.mark.parametrize(