Skip to content

Commit

Permalink
Merged custom driver to branch 2.3 (#1827)
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz authored Jun 26, 2023
1 parent e330ec6 commit 2994afc
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 42 deletions.
24 changes: 14 additions & 10 deletions nvflare/fuel/f3/comm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@

class VarName:

MAX_MSG_SIZE = "max_message_size"
ALLOW_ADHOC_CONNECTIONS = "allow_adhoc_conns"
ADHOC_CONNECTION_SCHEME = "adhoc_conn_scheme"
INTERNAL_CONNECTION_SCHEME = "internal_conn_scheme"
BACKBONE_CONNECTION_GENERATION = "backbone_conn_gen"
MAX_MESSAGE_SIZE = "max_message_size"
ALLOW_ADHOC_CONNS = "allow_adhoc_conns"
ADHOC_CONN_SCHEME = "adhoc_conn_scheme"
INTERNAL_CONN_SCHEME = "internal_conn_scheme"
BACKBONE_CONN_GEN = "backbone_conn_gen"
SUBNET_HEARTBEAT_INTERVAL = "subnet_heartbeat_interval"
SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold"
COMM_DRIVER_PATH = "comm_driver_path"


class CommConfigurator:
Expand All @@ -55,22 +56,25 @@ def get_config(self):
return self.config

def get_max_message_size(self):
return ConfigService.get_int_var(VarName.MAX_MSG_SIZE, self.config, default=DEFAULT_MAX_MSG_SIZE)
return ConfigService.get_int_var(VarName.MAX_MESSAGE_SIZE, self.config, default=DEFAULT_MAX_MSG_SIZE)

def allow_adhoc_connections(self, default):
return ConfigService.get_bool_var(VarName.ALLOW_ADHOC_CONNECTIONS, self.config, default=default)
return ConfigService.get_bool_var(VarName.ALLOW_ADHOC_CONNS, self.config, default=default)

def get_adhoc_connection_scheme(self, default):
return ConfigService.get_str_var(VarName.ADHOC_CONNECTION_SCHEME, self.config, default=default)
return ConfigService.get_str_var(VarName.ADHOC_CONN_SCHEME, self.config, default=default)

def get_internal_connection_scheme(self, default):
return ConfigService.get_str_var(VarName.INTERNAL_CONNECTION_SCHEME, self.config, default=default)
return ConfigService.get_str_var(VarName.INTERNAL_CONN_SCHEME, self.config, default=default)

def get_backbone_connection_generation(self, default):
return ConfigService.get_int_var(VarName.BACKBONE_CONNECTION_GENERATION, self.config, default=default)
return ConfigService.get_int_var(VarName.BACKBONE_CONN_GEN, self.config, default=default)

def get_subnet_heartbeat_interval(self, default):
return ConfigService.get_int_var(VarName.SUBNET_HEARTBEAT_INTERVAL, self.config, default)

def get_subnet_trouble_threshold(self, default):
return ConfigService.get_int_var(VarName.SUBNET_TROUBLE_THRESHOLD, self.config, default)

def get_comm_driver_path(self, default):
return ConfigService.get_str_var(VarName.COMM_DRIVER_PATH, self.config, default=default)
47 changes: 36 additions & 11 deletions nvflare/fuel/f3/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional

from nvflare.fuel.f3 import drivers
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.drivers.driver import Driver
from nvflare.fuel.f3.drivers.driver_manager import DriverManager
Expand All @@ -30,24 +31,31 @@

log = logging.getLogger(__name__)
_running_instances = weakref.WeakSet()
driver_mgr = DriverManager()
driver_loaded = False


def _exit_func():
for c in _running_instances:
c.stop()
log.debug(f"Communicator {c.local_endpoint.name} was left running, stopped on exit")
def load_comm_drivers():
global driver_loaded

# Load all the drivers in the drivers module
driver_mgr.search_folder(os.path.dirname(drivers.__file__), drivers.__package__)

atexit.register(_exit_func)
# Load custom drivers
driver_path = CommConfigurator().get_comm_driver_path(None)
if not driver_path:
return

for path in driver_path.split(os.pathsep):
log.debug(f"Custom driver folder {path} is searched")
driver_mgr.search_folder(path, None)

driver_loaded = True


class Communicator:
"""FCI (Flare Communication Interface) main communication API"""

driver_mgr = DriverManager()
# Load all the drivers in the drivers module
driver_mgr.register_folder(os.path.dirname(drivers.__file__), drivers.__package__)

def __init__(self, local_endpoint: Endpoint):
self.local_endpoint = local_endpoint
self.monitors = []
Expand Down Expand Up @@ -161,7 +169,10 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> str:
CommError: If any errors
"""

driver_class = self.driver_mgr.find_driver_class(url)
if not driver_loaded:
load_comm_drivers()

driver_class = driver_mgr.find_driver_class(url)
if not driver_class:
raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}")

Expand All @@ -182,7 +193,10 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str):
CommError: If any errors like invalid host or port not available
"""

driver_class = self.driver_mgr.find_driver_class(scheme)
if not driver_loaded:
load_comm_drivers()

driver_class = driver_mgr.find_driver_class(scheme)
if not driver_class:
raise CommError(CommError.NOT_SUPPORTED, f"No driver found for scheme {scheme}")

Expand Down Expand Up @@ -243,3 +257,14 @@ def remove_connector(self, handle: str):
CommError: If any errors
"""
self.conn_manager.remove_connector(handle)


def _exit_func():
while _running_instances:
c = next(iter(_running_instances))
# This call will remove the entry from the set
c.stop()
log.debug(f"Communicator {c.local_endpoint.name} was left running, stopped on exit")


atexit.register(_exit_func)
55 changes: 36 additions & 19 deletions nvflare/fuel/f3/drivers/driver_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import inspect
import logging
import os
import sys
from typing import Optional, Type

from nvflare.fuel.f3.comm_error import CommError
Expand All @@ -27,8 +28,8 @@ class DriverManager:
"""Transport driver manager"""

def __init__(self):
# scheme-<
self.drivers = {}
self.class_cache = set()

def register(self, driver_class: Type[Driver]):
"""Register a driver with Driver Manager
Expand All @@ -51,29 +52,45 @@ def register(self, driver_class: Type[Driver]):
self.drivers[key] = driver_class
log.debug(f"Driver {driver_class.__name__} is registered for {scheme}")

def register_folder(self, folder: str, package: str):
"""Scan the folder and register all drivers
def search_folder(self, folder: str, package: Optional[str]):
"""Search the folder recursively and register all drivers
Args:
folder: The folder to scan
package: The root package for all the drivers
package: The root package for all the drivers. If none, the folder is the
root of the packages
"""

class_cache = set()

for file_name in os.listdir(folder):
if file_name != "__init__.py" and file_name[-3:] == ".py":
module = package + "." + file_name[:-3]
imported = importlib.import_module(module)
for _, cls_obj in inspect.getmembers(imported, inspect.isclass):
if cls_obj.__name__ in class_cache:
continue
class_cache.add(cls_obj.__name__)

spec = inspect.getfullargspec(cls_obj.__init__)
# classes who are abstract or take extra args in __init__ can't be auto-registered
if issubclass(cls_obj, Driver) and not inspect.isabstract(cls_obj) and len(spec.args) == 1:
self.register(cls_obj)
if package is None and folder not in sys.path:
sys.path.append(folder)

for root, dirs, files in os.walk(folder):
for filename in files:
if filename.endswith(".py"):
module = filename[:-3]
sub_folder = root[len(folder) :]
if sub_folder:
sub_folder = sub_folder.strip("/").replace("/", ".")

if sub_folder:
module = sub_folder + "." + module

if package:
module = package + "." + module

imported = importlib.import_module(module)
for _, cls_obj in inspect.getmembers(imported, inspect.isclass):
if cls_obj.__name__ in self.class_cache:
continue
self.class_cache.add(cls_obj.__name__)

if issubclass(cls_obj, Driver) and not inspect.isabstract(cls_obj):
spec = inspect.getfullargspec(cls_obj.__init__)
if len(spec.args) == 1:
self.register(cls_obj)
else:
# Can't handle argument in constructor
log.warning(f"Invalid driver, __init__ with extra arguments: {module}")

def find_driver_class(self, scheme_or_url: str) -> Optional[Type[Driver]]:
"""Find the driver class based on scheme or URL
Expand Down
3 changes: 2 additions & 1 deletion nvflare/fuel/utils/config_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def _any_var(cls, func, name, conf, default):
if name in cls._var_values:
return cls._var_values.get(name)
v = func(name, conf, default)
cls._var_values[name] = v
if v is not None:
cls._var_values[name] = v
return v

@classmethod
Expand Down
13 changes: 13 additions & 0 deletions tests/unit_test/data/custom_drivers/com/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions tests/unit_test/data/custom_drivers/com/example/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
46 changes: 46 additions & 0 deletions tests/unit_test/data/custom_drivers/com/example/warp_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List

from nvflare.fuel.f3.drivers.base_driver import BaseDriver
from nvflare.fuel.f3.drivers.connector_info import ConnectorInfo
from nvflare.fuel.f3.drivers.driver_params import DriverCap


class WarpDriver(BaseDriver):
"""A dummy driver to test custom driver loading"""

def __init__(self):
super().__init__()

@staticmethod
def supported_transports() -> List[str]:
return ["warp"]

@staticmethod
def capabilities() -> Dict[str, Any]:
return {DriverCap.HEARTBEAT.value: False, DriverCap.SUPPORT_SSL.value: False}

def listen(self, connector: ConnectorInfo):
self.connector = connector

def connect(self, connector: ConnectorInfo):
self.connector = connector

def shutdown(self):
self.close_all()

@staticmethod
def get_urls(scheme: str, resources: dict) -> (str, str):
return "warp:enterprise"
3 changes: 3 additions & 0 deletions tests/unit_test/data/custom_drivers/config/comm_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"comm_driver_path": "./tests/unit_test/data/custom_drivers"
}
36 changes: 36 additions & 0 deletions tests/unit_test/fuel/f3/drivers/custom_driver_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest

from nvflare.fuel.f3 import communicator

# Setup custom driver path before communicator module initialization
from nvflare.fuel.utils.config_service import ConfigService


class TestCustomDriver:
@pytest.fixture
def manager(self):
rel_path = "../../../data/custom_drivers/config"
config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), rel_path))
ConfigService.initialize({}, [config_path])
communicator.load_comm_drivers()

return communicator.driver_mgr

def test_custom_driver_loading(self, manager):
driver_class = manager.find_driver_class("warp")
assert driver_class.__name__ == "WarpDriver"
2 changes: 1 addition & 1 deletion tests/unit_test/fuel/f3/drivers/driver_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TestDriverManager:
@pytest.fixture
def manager(self):
driver_manager = DriverManager()
driver_manager.register_folder(os.path.dirname(drivers.__file__), drivers.__package__)
driver_manager.search_folder(os.path.dirname(drivers.__file__), drivers.__package__)
return driver_manager

@pytest.mark.parametrize(
Expand Down

0 comments on commit 2994afc

Please sign in to comment.