From 4946ac6eb992a2a71dcb377850aacdafc2df3bc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Fri, 8 Mar 2024 13:31:44 -0800 Subject: [PATCH] [2.4] Add xgboost example, unit tests, integration tests (#2392) --- .readthedocs.yaml | 2 +- build_doc.sh | 2 +- .../base/app/config/config_fed_client.json | 7 +- .../base/app/config/config_fed_server.json | 3 - .../base_v2/app/config/config_fed_client.json | 38 ++++++ .../base_v2/app/config/config_fed_server.json | 16 +++ .../base_v2/app/custom/higgs_data_loader.py | 77 ++++++++++++ .../histogram-based/jobs/base_v2/meta.json | 10 ++ .../advanced/xgboost/prepare_job_config.sh | 2 + .../app/config/config_fed_server.json | 9 +- .../app/config/config_fed_client.json | 1 - .../app/config/config_fed_server.json | 9 +- .../xgboost/utils/prepare_job_config.py | 24 ++-- .../adaptors/grpc_client_adaptor.py | 32 ++++- .../adaptors/grpc_server_adaptor.py | 34 ++++- .../histogram_based_v2/grpc/grpc_client.py | 3 +- .../histogram_based_v2/grpc/grpc_server.py | 3 +- setup.cfg | 1 + .../standalone_job/hello_numpy_examples.yml | 16 +++ .../standalone_job/xgb_histogram_examples.yml | 95 ++++++++++++++ .../standalone_job/xgb_tree_examples.yml | 91 ++++++++++++++ .../integration_test/run_integration_tests.sh | 16 ++- tests/integration_test/src/example.py | 32 ++++- .../integration_test/src/mock_xgb/__init__.py | 13 ++ .../src/mock_xgb/aggr_servicer.py | 112 +++++++++++++++++ .../src/mock_xgb/mock_client_runner.py | 112 +++++++++++++++++ .../src/mock_xgb/mock_controller.py | 62 +++++++++ .../src/mock_xgb/mock_executor.py | 45 +++++++ .../src/mock_xgb/mock_server_runner.py | 48 +++++++ .../src/mock_xgb/run_client.py | 96 ++++++++++++++ .../src/mock_xgb/run_server.py | 44 +++++++ tests/integration_test/src/utils.py | 118 ++++++++++-------- .../src/validators/np_sag_result_validator.py | 5 +- tests/integration_test/test_configs.yml | 3 + tests/integration_test/xgb_test.py | 68 ++++++++++ 
tests/unit_test/app_opt/xgboost/__init__.py | 13 ++ .../xgboost/histrogram_based_v2/__init__.py | 13 ++ .../histrogram_based_v2/adaptor_test.py | 76 +++++++++++ .../histrogram_based_v2/adaptors/__init__.py | 13 ++ .../adaptors/grpc_client_adaptor_test.py | 39 ++++++ .../adaptors/grpc_server_adaptor_test.py | 36 ++++++ .../adaptors/mock_runner.py | 45 +++++++ 42 files changed, 1388 insertions(+), 96 deletions(-) create mode 100755 examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_client.json create mode 100755 examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_server.json create mode 100644 examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py create mode 100644 examples/advanced/xgboost/histogram-based/jobs/base_v2/meta.json create mode 100644 tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml create mode 100644 tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml create mode 100644 tests/integration_test/src/mock_xgb/__init__.py create mode 100644 tests/integration_test/src/mock_xgb/aggr_servicer.py create mode 100644 tests/integration_test/src/mock_xgb/mock_client_runner.py create mode 100644 tests/integration_test/src/mock_xgb/mock_controller.py create mode 100644 tests/integration_test/src/mock_xgb/mock_executor.py create mode 100644 tests/integration_test/src/mock_xgb/mock_server_runner.py create mode 100644 tests/integration_test/src/mock_xgb/run_client.py create mode 100644 tests/integration_test/src/mock_xgb/run_server.py create mode 100644 tests/integration_test/xgb_test.py create mode 100644 tests/unit_test/app_opt/xgboost/__init__.py create mode 100644 tests/unit_test/app_opt/xgboost/histrogram_based_v2/__init__.py create mode 100644 tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptor_test.py create mode 100644 tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/__init__.py create mode 100644 
tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_client_adaptor_test.py create mode 100644 tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_server_adaptor_test.py create mode 100644 tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/mock_runner.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 41450770c5..3a23c1993a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -26,6 +26,6 @@ sphinx: python: install: - method: pip - path: .[doc] + path: .[dev] # system_packages: true diff --git a/build_doc.sh b/build_doc.sh index e91c1a5331..384b6d1fa6 100755 --- a/build_doc.sh +++ b/build_doc.sh @@ -49,7 +49,7 @@ function clean_docs() { } function build_html_docs() { - pip install -e .[doc] + pip install -e .[dev] sphinx-apidoc --module-first -f -o docs/apidocs/ nvflare "*poc" "*private" sphinx-build -b html docs docs/_build } diff --git a/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json b/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json index c456968c25..c142310a3f 100755 --- a/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json +++ b/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json @@ -1,5 +1,6 @@ { "format_version": 2, + "num_rounds": 100, "executors": [ { "tasks": [ @@ -7,17 +8,17 @@ ], "executor": { "id": "Executor", - "name": "FedXGBHistogramExecutor", + "path": "nvflare.app_opt.xgboost.histogram_based.executor.FedXGBHistogramExecutor", "args": { "data_loader_id": "dataloader", - "num_rounds": 100, + "num_rounds": "{num_rounds}", "early_stopping_rounds": 2, "xgb_params": { "max_depth": 8, "eta": 0.1, "objective": "binary:logistic", "eval_metric": "auc", - "tree_method": "gpu_hist", + "tree_method": "hist", "nthread": 16 } } diff --git a/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_server.json 
b/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_server.json index 036047b501..c759f3a703 100755 --- a/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_server.json +++ b/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_server.json @@ -1,8 +1,5 @@ { "format_version": 2, - "server": { - "heart_beat_timeout": 600 - }, "task_data_filters": [], "task_result_filters": [], "components": [], diff --git a/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_client.json b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_client.json new file mode 100755 index 0000000000..337ddd84ca --- /dev/null +++ b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_client.json @@ -0,0 +1,38 @@ +{ + "format_version": 2, + "num_rounds": 100, + "executors": [ + { + "tasks": [ + "config", "start" + ], + "executor": { + "id": "Executor", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.executor.FedXGBHistogramExecutor", + "args": { + "data_loader_id": "dataloader", + "early_stopping_rounds": 2, + "xgb_params": { + "max_depth": 8, + "eta": 0.1, + "objective": "binary:logistic", + "eval_metric": "auc", + "tree_method": "hist", + "nthread": 16 + } + } + } + } + ], + "task_result_filters": [], + "task_data_filters": [], + "components": [ + { + "id": "dataloader", + "path": "higgs_data_loader.HIGGSDataLoader", + "args": { + "data_split_filename": "data_split.json" + } + } + ] +} diff --git a/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_server.json b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_server.json new file mode 100755 index 0000000000..7f92707d78 --- /dev/null +++ b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/config/config_fed_server.json @@ -0,0 +1,16 @@ +{ + "format_version": 2, + "num_rounds": 100, + "task_data_filters": [], + "task_result_filters": [], 
+ "components": [], + "workflows": [ + { + "id": "xgb_controller", + "path": "nvflare.app_opt.xgboost.histogram_based_v2.controller.XGBFedController", + "args": { + "num_rounds": "{num_rounds}" + } + } + ] +} \ No newline at end of file diff --git a/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py new file mode 100644 index 0000000000..3725dc5830 --- /dev/null +++ b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import pandas as pd +import xgboost as xgb + +from nvflare.app_opt.xgboost.data_loader import XGBDataLoader + + +def _read_higgs_with_pandas(data_path, start: int, end: int): + data_size = end - start + data = pd.read_csv(data_path, header=None, skiprows=start, nrows=data_size) + data_num = data.shape[0] + + # split to feature and label + x = data.iloc[:, 1:].copy() + y = data.iloc[:, 0].copy() + + return x, y, data_num + + +class HIGGSDataLoader(XGBDataLoader): + def __init__(self, data_split_filename): + """Reads HIGGS dataset and return XGB data matrix. 
+ + Args: + data_split_filename: file name to data splits + """ + self.data_split_filename = data_split_filename + + def load_data(self, client_id: str): + with open(self.data_split_filename, "r") as file: + data_split = json.load(file) + + data_path = data_split["data_path"] + data_index = data_split["data_index"] + + # check if site_id and "valid" in the mapping dict + if client_id not in data_index.keys(): + raise ValueError( + f"Data does not contain Client {client_id} split", + ) + + if "valid" not in data_index.keys(): + raise ValueError( + "Data does not contain Validation split", + ) + + site_index = data_index[client_id] + valid_index = data_index["valid"] + + # training + x_train, y_train, total_train_data_num = _read_higgs_with_pandas( + data_path=data_path, start=site_index["start"], end=site_index["end"] + ) + dmat_train = xgb.DMatrix(x_train, label=y_train) + + # validation + x_valid, y_valid, total_valid_data_num = _read_higgs_with_pandas( + data_path=data_path, start=valid_index["start"], end=valid_index["end"] + ) + dmat_valid = xgb.DMatrix(x_valid, label=y_valid) + + return dmat_train, dmat_valid diff --git a/examples/advanced/xgboost/histogram-based/jobs/base_v2/meta.json b/examples/advanced/xgboost/histogram-based/jobs/base_v2/meta.json new file mode 100644 index 0000000000..6d82211a16 --- /dev/null +++ b/examples/advanced/xgboost/histogram-based/jobs/base_v2/meta.json @@ -0,0 +1,10 @@ +{ + "name": "xgboost_histogram_based_v2", + "resource_spec": {}, + "deploy_map": { + "app": [ + "@ALL" + ] + }, + "min_clients": 2 +} diff --git a/examples/advanced/xgboost/prepare_job_config.sh b/examples/advanced/xgboost/prepare_job_config.sh index 6af3117a52..e04e6e589c 100755 --- a/examples/advanced/xgboost/prepare_job_config.sh +++ b/examples/advanced/xgboost/prepare_job_config.sh @@ -22,4 +22,6 @@ prepare_job_config 20 cyclic uniform uniform $TREE_METHOD prepare_job_config 2 histogram uniform uniform $TREE_METHOD prepare_job_config 5 histogram uniform 
uniform $TREE_METHOD +prepare_job_config 2 histogram_v2 uniform uniform $TREE_METHOD +prepare_job_config 5 histogram_v2 uniform uniform $TREE_METHOD echo "Job configs generated" diff --git a/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json b/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json index 124a2296f9..1eda9e9ad1 100755 --- a/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json +++ b/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json @@ -1,11 +1,6 @@ { "format_version": 2, - - "server": { - "heart_beat_timeout": 600, - "task_request_interval": 0.05 - }, - + "num_rounds": 101, "task_data_filters": [], "task_result_filters": [], @@ -34,7 +29,7 @@ "name": "ScatterAndGather", "args": { "min_clients": 5, - "num_rounds": 101, + "num_rounds": "{num_rounds}", "start_round": 0, "wait_time_after_min_received": 0, "aggregator_id": "aggregator", diff --git a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json index 319c26de1a..6b25f996bb 100755 --- a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json +++ b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json @@ -13,7 +13,6 @@ "data_loader_id": "dataloader", "training_mode": "cyclic", "num_client_bagging": 1, - "lr_mode": "scaled", "local_model_path": "model.json", "global_model_path": "model_global.json", "learning_rate": 0.1, diff --git a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json index 686042e1b5..f2826926df 100755 --- a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json +++ 
b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json @@ -1,11 +1,6 @@ { "format_version": 2, - - "server": { - "heart_beat_timeout": 600, - "task_request_interval": 0.05 - }, - + "num_rounds": 20, "task_data_filters": [], "task_result_filters": [], @@ -29,7 +24,7 @@ "id": "cyclic_ctl", "name": "CyclicController", "args": { - "num_rounds": 20, + "num_rounds": "{num_rounds}", "task_assignment_timeout": 60, "persistor_id": "persistor", "shareable_generator_id": "shareable_generator", diff --git a/examples/advanced/xgboost/utils/prepare_job_config.py b/examples/advanced/xgboost/utils/prepare_job_config.py index f970faa016..33523081a4 100644 --- a/examples/advanced/xgboost/utils/prepare_job_config.py +++ b/examples/advanced/xgboost/utils/prepare_job_config.py @@ -20,8 +20,16 @@ from nvflare.apis.fl_constant import JobConstants +SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) +XGB_EXAMPLE_ROOT = SCRIPT_PATH.parent.parent.absolute() JOB_CONFIGS_ROOT = "jobs" -MODE_ALGO_MAP = {"bagging": "tree-based", "cyclic": "tree-based", "histogram": "histogram-based"} +MODE_ALGO_MAP = { + "bagging": "tree-based", + "cyclic": "tree-based", + "histogram": "histogram-based", + "histogram_v2": "histogram-based", +} +BASE_JOB_MAP = {"bagging": "bagging_base", "cyclic": "cyclic_base", "histogram": "base", "histogram_v2": "base_v2"} def job_config_args_parser(): @@ -79,12 +87,7 @@ def _get_data_split_name(args, site_name: str) -> str: def _get_src_job_dir(training_mode): - base_job_map = { - "bagging": "bagging_base", - "cyclic": "cyclic_base", - "histogram": "base", - } - return pathlib.Path(MODE_ALGO_MAP[training_mode]) / JOB_CONFIGS_ROOT / base_job_map[training_mode] + return XGB_EXAMPLE_ROOT / MODE_ALGO_MAP[training_mode] / JOB_CONFIGS_ROOT / BASE_JOB_MAP[training_mode] def _gen_deploy_map(num_sites: int, site_name_prefix: str) -> dict: @@ -133,6 +136,7 @@ def _update_client_config(config: dict, args, lr_scale, site_name: str): 
num_client_bagging = args.site_num config["executors"][0]["executor"]["args"]["num_client_bagging"] = num_client_bagging else: + config["num_rounds"] = args.round_num config["components"][0]["args"]["data_split_filename"] = data_split_name config["executors"][0]["executor"]["args"]["xgb_params"]["nthread"] = args.nthread config["executors"][0]["executor"]["args"]["xgb_params"]["tree_method"] = args.tree_method @@ -140,10 +144,10 @@ def _update_client_config(config: dict, args, lr_scale, site_name: str): def _update_server_config(config: dict, args): if args.training_mode == "bagging": - config["workflows"][0]["args"]["num_rounds"] = args.round_num + 1 + config["num_rounds"] = args.round_num + 1 config["workflows"][0]["args"]["min_clients"] = args.site_num elif args.training_mode == "cyclic": - config["workflows"][0]["args"]["num_rounds"] = int(args.round_num / args.site_num) + config["num_rounds"] = int(args.round_num / args.site_num) def _copy_custom_files(src_job_path, src_app_name, dst_job_path, dst_app_name): @@ -198,7 +202,7 @@ def main(): src_job_path = _get_src_job_dir(args.training_mode) # create a new job - dst_job_path = pathlib.Path(MODE_ALGO_MAP[args.training_mode]) / JOB_CONFIGS_ROOT / job_name + dst_job_path = XGB_EXAMPLE_ROOT / MODE_ALGO_MAP[args.training_mode] / JOB_CONFIGS_ROOT / job_name if not os.path.exists(dst_job_path): os.makedirs(dst_job_path) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py index 3b3d7a154d..5e046ec1c3 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py @@ -13,7 +13,6 @@ # limitations under the License. 
import multiprocessing -import sys import threading from typing import Tuple @@ -59,11 +58,39 @@ def start(self, ctx: dict): class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer): + """Implementation of XGBClientAdaptor that uses an internal `GrpcServer`. + + The `GrpcClientAdaptor` class serves as an interface between the XGBoost + federated client and federated server components. + It employs its `XGBRunner` to initiate an XGBoost federated gRPC client + and utilizes an internal `GrpcServer` to forward client requests/responses. + + The communication flow is as follows: + + 1. XGBoost federated gRPC client talks to `GrpcClientAdaptor`, which + encapsulates a `GrpcServer`. + Requests are then forwarded to `GrpcServerAdaptor`, which internally + manages a `GrpcClient` responsible for interacting with the XGBoost + federated gRPC server. + 2. XGBoost federated gRPC server talks to `GrpcServerAdaptor`, which + encapsulates a `GrpcClient`. + Responses are then forwarded to `GrpcClientAdaptor`, which internally + manages a `GrpcServer` responsible for interacting with the XGBoost + federated gRPC client. + """ + def __init__( self, int_server_grpc_options=None, in_process=False, ): + """Constructor method to initialize the object. + + Args: + int_server_grpc_options: An optional list of key-value pairs (`channel_arguments` + in gRPC Core runtime) to configure the gRPC channel of internal `GrpcServer`. + in_process (bool): Specifies whether to start the `XGBRunner` in the same process or not. 
+ """ XGBClientAdaptor.__init__(self) self.int_server_grpc_options = int_server_grpc_options self.in_process = in_process @@ -212,7 +239,6 @@ def _abort(self, reason: str): def Allgather(self, request: pb2.AllgatherRequest, context): try: - self.logger.info(f"Calling Allgather with {sys.getsizeof(request.send_buffer)}") rcv_buf = self._send_all_gather( rank=request.rank, seq=request.sequence_number, @@ -225,7 +251,6 @@ def Allgather(self, request: pb2.AllgatherRequest, context): def Allreduce(self, request: pb2.AllreduceRequest, context): try: - self.logger.info(f"Calling Allreduce with {sys.getsizeof(request.send_buffer)}") rcv_buf = self._send_all_reduce( rank=request.rank, seq=request.sequence_number, @@ -240,7 +265,6 @@ def Allreduce(self, request: pb2.AllreduceRequest, context): def Broadcast(self, request: pb2.BroadcastRequest, context): try: - self.logger.info(f"Calling Broadcast with {sys.getsizeof(request.send_buffer)}") rcv_buf = self._send_broadcast( rank=request.rank, seq=request.sequence_number, diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_server_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_server_adaptor.py index e886faf020..1e5bf48507 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_server_adaptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_server_adaptor.py @@ -26,12 +26,42 @@ class GrpcServerAdaptor(XGBServerAdaptor): + """Implementation of XGBServerAdaptor that uses an internal `GrpcClient`. + + The `GrpcServerAdaptor` class serves as an interface between the XGBoost + federated client and federated server components. + It employs its `XGBRunner` to initiate an XGBoost federated gRPC server + and utilizes an internal `GrpcClient` to forward client requests/responses. + + The communication flow is as follows: + + 1. XGBoost federated gRPC client talks to `GrpcClientAdaptor`, which + encapsulates a `GrpcServer`. 
+ Requests are then forwarded to `GrpcServerAdaptor`, which internally + manages a `GrpcClient` responsible for interacting with the XGBoost + federated gRPC server. + 2. XGBoost federated gRPC server talks to `GrpcServerAdaptor`, which + encapsulates a `GrpcClient`. + Responses are then forwarded to `GrpcClientAdaptor`, which internally + manages a `GrpcServer` responsible for interacting with the XGBoost + federated gRPC client. + """ + def __init__( self, int_client_grpc_options=None, - xgb_server_ready_timeout=Constant.XGB_SERVER_READY_TIMEOUT, - in_process=True, + in_process: bool = False, + xgb_server_ready_timeout: float = Constant.XGB_SERVER_READY_TIMEOUT, ): + """Constructor method to initialize the object. + + Args: + int_client_grpc_options: An optional list of key-value pairs (`channel_arguments` + in gRPC Core runtime) to configure the gRPC channel of internal `GrpcClient`. + in_process (bool): Specifies whether to start the `XGBRunner` in the same process or not. + xgb_server_ready_timeout (float): Duration for which the internal `GrpcClient` + should wait for the XGBoost gRPC server before timing out. + """ XGBServerAdaptor.__init__(self) self.int_client_grpc_options = int_client_grpc_options self.xgb_server_ready_timeout = xgb_server_ready_timeout diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_client.py b/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_client.py index f4800d05cb..ea518099da 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_client.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_client.py @@ -32,7 +32,8 @@ def __init__( Args: server_addr: address of the gRPC server to connect to - grpc_options: gRPC options for the gRPC client + grpc_options: An optional list of key-value pairs (`channel_arguments` + in gRPC Core runtime) to configure the gRPC channel. 
""" if not grpc_options: grpc_options = GRPC_DEFAULT_OPTIONS diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_server.py b/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_server.py index 7cd773cdb0..e8d209003c 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_server.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/grpc/grpc_server.py @@ -42,7 +42,8 @@ def __init__( addr: the listening address of the server max_workers: max number of workers servicer: the servicer that is capable of processing XGB requests - grpc_options: gRPC options + grpc_options: An optional list of key-value pairs (`channel_arguments` + in gRPC Core runtime) to configure the gRPC channel. """ if not grpc_options: grpc_options = GRPC_DEFAULT_OPTIONS diff --git a/setup.cfg b/setup.cfg index 7b1b6e834f..f8a58ab106 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,6 +45,7 @@ PT = torchvision SKLEARN = scikit-learn + pandas>=1.5.1 TRACKING = mlflow wandb diff --git a/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml b/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml index e578f70420..f3a87da6d7 100644 --- a/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml +++ b/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml @@ -43,3 +43,19 @@ tests: args: { server_model_names: ["server"] } - path: tests.integration_test.src.validators.NumpySAGResultValidator args: { expected_result: [ [ 4, 5, 6 ], [ 7, 8, 9 ], [ 10, 11, 12 ] ] } + - test_name: "run hello-ccwf" + # TODO: add a result validator for the "models" saved on client site (ccwf) + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job hello-ccwf/jobs/swarm_cse_numpy" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + 
"type": "run_state" + "data": { "run_finished": True } diff --git a/tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml b/tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml new file mode 100644 index 0000000000..8b140c11ff --- /dev/null +++ b/tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml @@ -0,0 +1,95 @@ +n_servers: 1 +n_clients: 2 +additional_python_paths: +- ../../examples/advanced/xgboost +cleanup: true +jobs_root_dir: ../../examples/advanced/xgboost/histogram-based/jobs + + +tests: +- test_name: Test a simplified copy of job higgs_2_histogram_uniform_split_uniform_lr + for xgboost histogram-based example. + event_sequence: + - actions: + - submit_job higgs_2_histogram_uniform_split_uniform_lr_copy + result: + type: job_submit_success + trigger: + data: Server started + type: server_log + - actions: + - ensure_current_job_done + result: + data: + run_finished: true + type: run_state + trigger: + data: + run_finished: true + type: run_state + setup: + - cp ../../examples/advanced/xgboost/histogram-based/requirements.txt + ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - sed -i '/nvflare\|jupyter\|notebook/d' ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - pip install -r ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - python3 ../../examples/advanced/xgboost/utils/prepare_data_split.py + --data_path /tmp/nvflare/dataset/HIGGS.csv + --site_num 2 + --size_total 110000 + --size_valid 10000 + --split_method uniform + --out_path "/tmp/nvflare/xgboost_higgs_dataset/2_uniform" + - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py + --site_num 2 + --training_mode histogram + --split_method uniform + --lr_mode uniform + --nthread 16 + --tree_method hist + - python3 convert_to_test_job.py + --job 
../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_uniform_split_uniform_lr + --post _copy + - rm -f ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + teardown: + - rm -rf ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_uniform_split_uniform_lr + - rm -rf ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_uniform_split_uniform_lr_copy + +- test_name: Test a simplified copy of job higgs_2_histogram_v2_uniform_split_uniform_lr + for xgboost histogram-based example. + event_sequence: + - actions: + - submit_job higgs_2_histogram_v2_uniform_split_uniform_lr + result: + type: job_submit_success + trigger: + data: Server started + type: server_log + - actions: + - ensure_current_job_done + result: + data: + run_finished: true + type: run_state + trigger: + data: + run_finished: true + type: run_state + setup: + - cp ../../examples/advanced/xgboost/histogram-based/requirements.txt + ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - sed -i '/nvflare\|jupyter\|notebook/d' ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - pip install -r ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py + --site_num 2 + --training_mode histogram_v2 + --split_method uniform + --lr_mode uniform + --nthread 16 + --tree_method hist + - python3 convert_to_test_job.py + --job ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_v2_uniform_split_uniform_lr + --post _copy + - rm -f ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + teardown: + - rm -rf ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_v2_uniform_split_uniform_lr + - rm -rf ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_v2_uniform_split_uniform_lr_copy diff --git 
a/tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml b/tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml new file mode 100644 index 0000000000..50cdfc3973 --- /dev/null +++ b/tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml @@ -0,0 +1,91 @@ +n_servers: 1 +n_clients: 5 +additional_python_paths: +- ../../examples/advanced/xgboost +cleanup: true +jobs_root_dir: ../../examples/advanced/xgboost/tree-based/jobs + + +tests: +- test_name: Test a simplified copy of job higgs_5_cyclic_uniform_split_uniform_lr + for xgboost tree-based example. + event_sequence: + - actions: + - submit_job higgs_5_cyclic_uniform_split_uniform_lr_copy + result: + type: job_submit_success + trigger: + data: Server started + type: server_log + - actions: + - ensure_current_job_done + result: + data: + run_finished: true + type: run_state + trigger: + data: + run_finished: true + type: run_state + setup: + - cp ../../examples/advanced/xgboost/tree-based/requirements.txt + ../../examples/advanced/xgboost/tree-based/temp_requirements.txt + - sed -i '/nvflare\|jupyter\|notebook/d' ../../examples/advanced/xgboost/tree-based/temp_requirements.txt + - pip install -r ../../examples/advanced/xgboost/tree-based/temp_requirements.txt + - python3 ../../examples/advanced/xgboost/utils/prepare_data_split.py + --data_path /tmp/nvflare/dataset/HIGGS.csv + --site_num 5 + --size_total 110000 + --size_valid 10000 + --split_method uniform + --out_path "/tmp/nvflare/xgboost_higgs_dataset/5_uniform" + - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py + --site_num 5 + --training_mode cyclic + --split_method uniform + --lr_mode uniform + --nthread 16 + --tree_method hist + - python3 convert_to_test_job.py + --job ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_cyclic_uniform_split_uniform_lr + --post _copy + - rm -f ../../examples/advanced/xgboost/tree-based/temp_requirements.txt + teardown: + - rm -rf 
../../examples/advanced/xgboost/tree-based/jobs/higgs_5_cyclic_uniform_split_uniform_lr + - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_cyclic_uniform_split_uniform_lr_copy + +- test_name: Test a simplified copy of job higgs_5_bagging_uniform_split_uniform_lr + for xgboost tree-based example. + event_sequence: + - actions: + - submit_job higgs_5_bagging_uniform_split_uniform_lr_copy + result: + type: job_submit_success + trigger: + data: Server started + type: server_log + - actions: + - ensure_current_job_done + result: + data: + run_finished: true + type: run_state + trigger: + data: + run_finished: true + type: run_state + setup: + - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py + --site_num 5 + --training_mode bagging + --split_method uniform + --lr_mode uniform + --nthread 16 + --tree_method hist + - python3 convert_to_test_job.py + --job ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_bagging_uniform_split_uniform_lr + --post _copy + - rm -f ../../examples/advanced/xgboost/tree-based/temp_requirements.txt + teardown: + - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_bagging_uniform_split_uniform_lr + - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_bagging_uniform_split_uniform_lr_copy diff --git a/tests/integration_test/run_integration_tests.sh b/tests/integration_test/run_integration_tests.sh index 6f8e68a981..43e1581104 100755 --- a/tests/integration_test/run_integration_tests.sh +++ b/tests/integration_test/run_integration_tests.sh @@ -3,7 +3,7 @@ set -e PYTHONPATH="${PWD}/../.." -backends=(numpy tensorflow pytorch overseer ha auth preflight cifar auto stats) +backends=(numpy tensorflow pytorch overseer ha auth preflight cifar auto stats xgboost) usage() { @@ -58,6 +58,18 @@ run_overseer_test() eval "$cmd" } +run_xgb_test() +{ + echo "Running xgboost integration tests." 
+ cmd="$prefix xgb_test.py" + echo "$cmd" + eval "$cmd" + echo "Running system integration tests with backend $m." + cmd="$prefix system_test.py" + echo "$cmd" + eval "$cmd" +} + run_system_test() { echo "Running system integration tests with backend $m." @@ -81,6 +93,8 @@ elif [[ $m == "overseer" ]]; then run_overseer_test elif [[ $m == "preflight" ]]; then run_preflight_check_test +elif [[ $m == "xgboost" ]]; then + run_xgb_test else run_system_test fi diff --git a/tests/integration_test/src/example.py b/tests/integration_test/src/example.py index f46d01b53c..58dbeeea35 100644 --- a/tests/integration_test/src/example.py +++ b/tests/integration_test/src/example.py @@ -17,7 +17,7 @@ class Example: - """This class represents a standardized example structure in NVFlare.""" + """This class represents a standardized example folder structure in NVFlare.""" def __init__( self, @@ -27,9 +27,37 @@ def __init__( additional_python_path: Optional[str] = None, prepare_data_script: Optional[str] = None, ): + """Constructor of Example. + + A standardized example folder looks like the following: + + .. code-block + + ./[example_root] + ./[jobs_folder_in_example] + ./job_name1 + ./job_name2 + ./job_name3 + ./[requirements] + ./[prepare_data_script] + + For example: + + .. code-block + + ./cifar10-sim + ./jobs + ./cifar10_central + ./cifar10_fedavg + ./cifar10_fedopt + ... + ./requirements.txt + ./prepare_data.sh + + """ self.root = os.path.abspath(root) if not os.path.exists(self.root): - raise FileNotFoundError("Example root directory does not exist.") + raise FileNotFoundError("Example's root directory does not exist.") self.name = os.path.basename(self.root) diff --git a/tests/integration_test/src/mock_xgb/__init__.py b/tests/integration_test/src/mock_xgb/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/integration_test/src/mock_xgb/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration_test/src/mock_xgb/aggr_servicer.py b/tests/integration_test/src/mock_xgb/aggr_servicer.py new file mode 100644 index 0000000000..01f0e2055c --- /dev/null +++ b/tests/integration_test/src/mock_xgb/aggr_servicer.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading + +import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2 +from nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2_grpc import FederatedServicer +from nvflare.fuel.utils.obj_utils import get_logger + + +class ReqWaiter: + def __init__(self, exp_num_clients: int, exp_seq: int, exp_op): + self.exp_num_clients = exp_num_clients + self.exp_seq = exp_seq + self.exp_op = exp_op + self.reqs = {} + self.result = {} + self.waiter = threading.Event() + + def add_request(self, op: str, rank, seq, req): + if seq != self.exp_seq: + raise RuntimeError(f"expecting seq {self.exp_seq} from {rank=} but got {seq}") + + if op != self.exp_op: + raise RuntimeError(f"expecting op {self.exp_op} from {rank=} but got {op}") + + if rank in self.reqs: + raise RuntimeError(f"duplicate request from {op=} {rank=} {seq=}") + + self.reqs[rank] = req + + if isinstance(req, pb2.AllgatherRequest): + reply = pb2.AllgatherReply(receive_buffer=req.send_buffer) + elif isinstance(req, pb2.AllreduceRequest): + reply = pb2.AllreduceReply(receive_buffer=req.send_buffer) + elif isinstance(req, pb2.BroadcastRequest): + reply = pb2.BroadcastReply(receive_buffer=req.send_buffer) + else: + raise RuntimeError(f"unknown request type {type(req)}") + self.result[rank] = reply + if len(self.reqs) == self.exp_num_clients: + self.waiter.set() + + def wait(self, timeout): + return self.waiter.wait(timeout) + + +class AggrServicer(FederatedServicer): + def __init__(self, num_clients, aggr_timeout=10.0): + self.logger = get_logger(self) + self.num_clients = num_clients + self.aggr_timeout = aggr_timeout + self.req_lock = threading.Lock() + self.req_waiter = None + + def _wait_for_result(self, op, rank, seq, request): + with self.req_lock: + if not self.req_waiter: + self.logger.info(f"setting new waiter: {seq=} {op=}") + self.req_waiter = ReqWaiter( + exp_num_clients=self.num_clients, + exp_seq=seq, + exp_op=op, + ) + self.req_waiter.add_request(op, rank, seq, request) 
+ if not self.req_waiter.wait(self.aggr_timeout): + self.logger.error(f"results not received from all ranks after {self.aggr_timeout} seconds") + self.logger.info(f"for {rank=}: results remaining: {self.req_waiter.result.keys()}") + with self.req_lock: + result = self.req_waiter.result.pop(rank, None) + if len(self.req_waiter.result) == 0: + self.logger.info("all results are retrieved - reset req_waiter to None") + self.req_waiter = None + return result + + def Allgather(self, request: pb2.AllgatherRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + op = "Allgather" + self.logger.info(f"got {op}: {seq=} {rank=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) + + def Allreduce(self, request: pb2.AllreduceRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + reduce_op = request.reduce_operation + data_type = request.data_type + op = "Allreduce" + self.logger.info(f"got {op}: {seq=} {rank=} {reduce_op=} {data_type=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) + + def Broadcast(self, request: pb2.BroadcastRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + root = request.root + op = "Broadcast" + self.logger.info(f"got {op}: {seq=} {rank=} {root=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) diff --git a/tests/integration_test/src/mock_xgb/mock_client_runner.py b/tests/integration_test/src/mock_xgb/mock_client_runner.py new file mode 100644 index 0000000000..5927ca1c8a --- /dev/null +++ b/tests/integration_test/src/mock_xgb/mock_client_runner.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2 +from nvflare.apis.fl_component import FLComponent +from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant +from nvflare.app_opt.xgboost.histogram_based_v2.grpc.grpc_client import GrpcClient +from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner + + +class MockClientRunner(XGBRunner, FLComponent): + def __init__(self): + FLComponent.__init__(self) + self.training_stopped = False + self.asked_to_stop = False + + def run(self, ctx: dict): + server_addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR) + rank = ctx.get(Constant.RUNNER_CTX_RANK) + num_rounds = ctx.get(Constant.RUNNER_CTX_NUM_ROUNDS) + + client = GrpcClient(server_addr=server_addr) + client.start() + + rank = rank + seq = 0 + total_time = 0 + total_reqs = 0 + for i in range(num_rounds): + if self.asked_to_stop: + self.logger.info("training aborted") + self.training_stopped = True + return + + self.logger.info(f"Test round {i}") + data = os.urandom(1000000) + + self.logger.info("sending allgather") + start = time.time() + result = client.send_allgather(seq_num=seq + 1, rank=rank, data=data) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllgatherReply): + self.logger.error(f"expect reply to be pb2.AllgatherReply but got {type(result)}") + elif result.receive_buffer != data: + self.logger.error("allgather result does not match request") + else: + self.logger.info("OK: allgather result matches request!") + + 
self.logger.info("sending allreduce") + start = time.time() + result = client.send_allreduce( + seq_num=seq + 3, + rank=rank, + data=data, + reduce_op=2, + data_type=2, + ) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllreduceReply): + self.logger.error(f"expect reply to be pb2.AllreduceReply but got {type(result)}") + elif result.receive_buffer != data: + self.logger.error("allreduce result does not match request") + else: + self.logger.info("OK: allreduce result matches request!") + print("OK: allreduce result matches request!") + + self.logger.info("sending broadcast") + start = time.time() + result = client.send_broadcast( + seq_num=seq + 4, + rank=rank, + data=data, + root=3, + ) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.BroadcastReply): + self.logger.error(f"expect reply to be pb2.BroadcastReply but got {type(result)}") + elif result.receive_buffer != data: + self.logger.error("ERROR: broadcast result does not match request") + else: + self.logger.info("OK: broadcast result matches request!") + + seq += 4 + time.sleep(1.0) + + time_per_req = total_time / total_reqs + self.logger.info(f"DONE: {total_reqs=} {total_time=} {time_per_req=}") + print(f"DONE: {total_reqs=} {total_time=} {time_per_req=}") + self.training_stopped = True + + def stop(self): + self.asked_to_stop = True + + def is_stopped(self) -> (bool, int): + return self.training_stopped, 0 diff --git a/tests/integration_test/src/mock_xgb/mock_controller.py b/tests/integration_test/src/mock_xgb/mock_controller.py new file mode 100644 index 0000000000..8d32a4d838 --- /dev/null +++ b/tests/integration_test/src/mock_xgb/mock_controller.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.xgboost.histogram_based_v2.adaptor_controller import XGBController +from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_server_adaptor import GrpcServerAdaptor +from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant + +from .mock_server_runner import MockServerRunner + + +class MockXGBController(XGBController): + def __init__( + self, + num_rounds: int, + configure_task_name=Constant.CONFIG_TASK_NAME, + configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT, + start_task_name=Constant.START_TASK_NAME, + start_task_timeout=Constant.START_TASK_TIMEOUT, + job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL, + max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL, + progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, + client_ranks=None, + int_client_grpc_options=None, + in_process=True, + ): + XGBController.__init__( + self, + adaptor_component_id="", + num_rounds=num_rounds, + configure_task_name=configure_task_name, + configure_task_timeout=configure_task_timeout, + start_task_name=start_task_name, + start_task_timeout=start_task_timeout, + job_status_check_interval=job_status_check_interval, + max_client_op_interval=max_client_op_interval, + progress_timeout=progress_timeout, + client_ranks=client_ranks, + ) + self.int_client_grpc_options = int_client_grpc_options + self.in_process = in_process + + def get_adaptor(self, fl_ctx: FLContext): + runner = MockServerRunner() + runner.initialize(fl_ctx) + adaptor = GrpcServerAdaptor( + 
int_client_grpc_options=self.int_client_grpc_options, + in_process=self.in_process, + ) + adaptor.set_runner(runner) + return adaptor diff --git a/tests/integration_test/src/mock_xgb/mock_executor.py b/tests/integration_test/src/mock_xgb/mock_executor.py new file mode 100644 index 0000000000..9a4bf87acc --- /dev/null +++ b/tests/integration_test/src/mock_xgb/mock_executor.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.xgboost.histogram_based_v2.adaptor_executor import XGBExecutor +from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_client_adaptor import GrpcClientAdaptor + +from .mock_client_runner import MockClientRunner + + +class MockXGBExecutor(XGBExecutor): + def __init__( + self, + int_server_grpc_options=None, + req_timeout=10.0, + in_process=True, + ): + XGBExecutor.__init__( + self, + adaptor_component_id="", + req_timeout=req_timeout, + ) + self.int_server_grpc_options = int_server_grpc_options + self.in_process = in_process + + def get_adaptor(self, fl_ctx: FLContext): + runner = MockClientRunner() + runner.initialize(fl_ctx) + adaptor = GrpcClientAdaptor( + int_server_grpc_options=self.int_server_grpc_options, + in_process=self.in_process, + ) + adaptor.set_runner(runner) + return adaptor diff --git a/tests/integration_test/src/mock_xgb/mock_server_runner.py b/tests/integration_test/src/mock_xgb/mock_server_runner.py new file mode 100644 index 0000000000..df701c7cf0 --- /dev/null +++ b/tests/integration_test/src/mock_xgb/mock_server_runner.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant +from nvflare.app_opt.xgboost.histogram_based_v2.grpc.grpc_server import GrpcServer +from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner + +from .aggr_servicer import AggrServicer + + +class MockServerRunner(XGBRunner): + def __init__(self, server_max_workers=10): + self.server_max_workers = server_max_workers + self._stopped = False + self._server = None + + def run(self, ctx: dict): + world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) + addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR) + + self._server = GrpcServer( + addr, + max_workers=self.server_max_workers, + grpc_options=None, + servicer=AggrServicer(num_clients=world_size), + ) + self._server.start(no_blocking=False) + + def stop(self): + s = self._server + self._server = None + if s: + s.shutdown() + self._stopped = True + + def is_stopped(self) -> (bool, int): + return self._stopped, 0 diff --git a/tests/integration_test/src/mock_xgb/run_client.py b/tests/integration_test/src/mock_xgb/run_client.py new file mode 100644 index 0000000000..8055bd659f --- /dev/null +++ b/tests/integration_test/src/mock_xgb/run_client.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import time + +import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2 +from nvflare.app_opt.xgboost.histogram_based_v2.grpc.grpc_client import GrpcClient + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--addr", "-a", type=str, help="server address", required=True) + parser.add_argument("--rank", "-r", type=int, help="client rank", required=True) + parser.add_argument("--num_rounds", "-n", type=int, help="number of rounds", required=True) + + args = parser.parse_args() + client = GrpcClient(server_addr=args.addr) + client.start() + + rank = args.rank + seq = 0 + total_time = 0 + total_reqs = 0 + for i in range(args.num_rounds): + print(f"Test round {i}") + data = os.urandom(1000000) + + print("sending allgather") + start = time.time() + result = client.send_allgather(seq_num=seq + 1, rank=rank, data=data) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllgatherReply): + print(f"expect reply to be pb2.AllgatherReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: allgather result does not match request") + else: + print("OK: allgather result matches request!") + + print("sending allreduce") + start = time.time() + result = client.send_allreduce( + seq_num=seq + 3, + rank=rank, + data=data, + reduce_op=2, + data_type=2, + ) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllreduceReply): + print(f"expect reply to be pb2.AllreduceReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: allreduce result does not match request") + else: + print("OK: allreduce result matches request!") + + print("sending broadcast") + start = time.time() + result = client.send_broadcast( + seq_num=seq + 4, + rank=rank, + data=data, + root=3, + ) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.BroadcastReply): + print(f"expect reply to be 
pb2.BroadcastReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: broadcast result does not match request") + else: + print("OK: broadcast result matches request!") + + seq += 4 + time.sleep(1.0) + + time_per_req = total_time / total_reqs + print(f"DONE: {total_reqs=} {total_time=} {time_per_req=}") + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/src/mock_xgb/run_server.py b/tests/integration_test/src/mock_xgb/run_server.py new file mode 100644 index 0000000000..6df116ffd3 --- /dev/null +++ b/tests/integration_test/src/mock_xgb/run_server.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import logging + +from aggr_servicer import AggrServicer + +from nvflare.app_opt.xgboost.histogram_based_v2.grpc.grpc_server import GrpcServer + + +def main(): + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--addr", "-a", type=str, help="server address", required=True) + parser.add_argument("--num_clients", "-c", type=int, help="number of clients", required=True) + parser.add_argument("--max_workers", "-w", type=int, help="max number of workers", required=False, default=20) + + args = parser.parse_args() + print(f"starting server at {args.addr} max_workers={args.max_workers}") + server = GrpcServer( + args.addr, + max_workers=args.max_workers, + grpc_options=None, + servicer=AggrServicer(num_clients=args.num_clients), + ) + server.start() + + +if __name__ == "__main__": + main() diff --git a/tests/integration_test/src/utils.py b/tests/integration_test/src/utils.py index ac0d09291e..0a945607c9 100644 --- a/tests/integration_test/src/utils.py +++ b/tests/integration_test/src/utils.py @@ -292,6 +292,7 @@ def _replace_config_fed_client(client_json_path: str): with open(client_json_path, "r+") as f: config_fed_client = json.load(f) config_fed_client["TRAIN_SPLIT_ROOT"] = "/tmp/nvflare/test_data" + config_fed_client["num_rounds"] = 2 config_fed_client["AGGREGATION_EPOCHS"] = 1 f.seek(0) json.dump(config_fed_client, f, indent=4) @@ -318,67 +319,84 @@ def simplify_job(job_folder_path: str, postfix: str = POSTFIX): def generate_test_config_yaml_for_example( example: Example, project_yaml: str = PROJECT_YAML, - postfix: str = POSTFIX, + job_postfix: str = POSTFIX, ) -> List[str]: - """Generates test configuration yaml for NVFlare example. + """Generates test configurations for an NVFlare example folder. Args: - example: A well-formatted NVFlare example. - project_yaml: Project yaml file for the testing of this example. - postfix: Postfix for the newly generated job. 
+ example (Example): A well-formatted NVFlare example folder. + project_yaml (str): Project yaml file for the testing of this example. + job_postfix (str): Postfix for the newly generated job. """ - output_yamls = [] os.makedirs(OUTPUT_YAML_DIR, exist_ok=True) for job in os.listdir(example.jobs_root_dir): - output_yaml = os.path.join(OUTPUT_YAML_DIR, f"{example.name}_{job}.yml") - job_dir = os.path.join(example.jobs_root_dir, job) - requirements_file = os.path.join(example.root, example.requirements_file) - new_requirements_file = os.path.join(example.root, "temp_requirements.txt") - exclude_requirements = "\\|".join(REQUIREMENTS_TO_EXCLUDE) - - setup = [ - f"cp {requirements_file} {new_requirements_file}", - f"sed -i '/{exclude_requirements}/d' {new_requirements_file}", - f"pip install -r {new_requirements_file}", - ] - if example.prepare_data_script is not None: - setup.append(f"bash {example.prepare_data_script}") - setup.append(f"python convert_to_test_job.py --job {job_dir} --post {postfix}") - setup.append(f"rm -f {new_requirements_file}") - - config = { - "ha": True, - "jobs_root_dir": example.jobs_root_dir, - "cleanup": True, - "project_yaml": project_yaml, - "additional_python_paths": example.additional_python_paths, - "tests": [ - { - "test_name": f"Test a simplified copy of job {job} for example {example.name}.", - "event_sequence": [ - { - "trigger": {"type": "server_log", "data": "Server started"}, - "actions": [f"submit_job {job}{postfix}"], - "result": {"type": "job_submit_success"}, - }, - { - "trigger": {"type": "run_state", "data": {"run_finished": True}}, - "actions": ["ensure_current_job_done"], - "result": {"type": "run_state", "data": {"run_finished": True}}, - }, - ], - "setup": setup, - "teardown": [f"rm -rf {job_dir}{postfix}"], - } - ], - } - with open(output_yaml, "w") as yaml_file: - yaml.dump(config, yaml_file, default_flow_style=False) + output_yaml = _generate_test_config_for_one_job(example, job, project_yaml, job_postfix) 
output_yamls.append(output_yaml) return output_yamls +def _generate_test_config_for_one_job( + example: Example, + job: str, + project_yaml: str = PROJECT_YAML, + postfix: str = POSTFIX, +) -> str: + """Generates test configuration yaml for an NVFlare example. + + Args: + example (Example): A well-formatted NVFlare example. + job (str): name of the job. + project_yaml (str): Project yaml file for the testing of this example. + postfix (str): Postfix for the newly generated job. + """ + output_yaml = os.path.join(OUTPUT_YAML_DIR, f"{example.name}_{job}.yml") + job_dir = os.path.join(example.jobs_root_dir, job) + requirements_file = os.path.join(example.root, example.requirements_file) + new_requirements_file = os.path.join(example.root, "temp_requirements.txt") + exclude_requirements = "\\|".join(REQUIREMENTS_TO_EXCLUDE) + + setup = [ + f"cp {requirements_file} {new_requirements_file}", + f"sed -i '/{exclude_requirements}/d' {new_requirements_file}", + f"pip install -r {new_requirements_file}", + ] + if example.prepare_data_script is not None: + setup.append(f"bash {example.prepare_data_script}") + setup.append(f"python convert_to_test_job.py --job {job_dir} --post {postfix}") + setup.append(f"rm -f {new_requirements_file}") + + config = { + "ha": True, + "jobs_root_dir": example.jobs_root_dir, + "cleanup": True, + "project_yaml": project_yaml, + "additional_python_paths": example.additional_python_paths, + "tests": [ + { + "test_name": f"Test a simplified copy of job {job} for example {example.name}.", + "event_sequence": [ + { + "trigger": {"type": "server_log", "data": "Server started"}, + "actions": [f"submit_job {job}{postfix}"], + "result": {"type": "job_submit_success"}, + }, + { + "trigger": {"type": "run_state", "data": {"run_finished": True}}, + "actions": ["ensure_current_job_done"], + "result": {"type": "run_state", "data": {"run_finished": True}}, + }, + ], + "setup": setup, + "teardown": [f"rm -rf {job_dir}{postfix}"], + } + ], + } + with 
open(output_yaml, "w") as yaml_file: + yaml.dump(config, yaml_file, default_flow_style=False) + return output_yaml + + def _read_admin_json_file(admin_json_file) -> dict: if not os.path.exists(admin_json_file): raise RuntimeError("Missing admin json file.") diff --git a/tests/integration_test/src/validators/np_sag_result_validator.py b/tests/integration_test/src/validators/np_sag_result_validator.py index 0e8690abb1..bbf6c94d63 100644 --- a/tests/integration_test/src/validators/np_sag_result_validator.py +++ b/tests/integration_test/src/validators/np_sag_result_validator.py @@ -20,9 +20,10 @@ class NumpySAGResultValidator(FinishJobResultValidator): - def __init__(self, expected_result): + def __init__(self, expected_result, model_name: str = "server.npy"): super().__init__() self.expected_result = np.array(expected_result) + self.model_name = model_name def validate_finished_results(self, job_result, client_props) -> bool: server_run_dir = job_result["workspace_root"] @@ -32,7 +33,7 @@ def validate_finished_results(self, job_result, client_props) -> bool: self.logger.error(f"models dir {models_dir} doesn't exist.") return False - model_path = os.path.join(models_dir, "server.npy") + model_path = os.path.join(models_dir, self.model_name) if not os.path.isfile(model_path): self.logger.error(f"model_path {model_path} doesn't exist.") return False diff --git a/tests/integration_test/test_configs.yml b/tests/integration_test/test_configs.yml index a179a6ce79..d75f59a784 100644 --- a/tests/integration_test/test_configs.yml +++ b/tests/integration_test/test_configs.yml @@ -29,3 +29,6 @@ test_configs: - ./data/test_configs/standalone_job/cifar_examples.yml stats: - ./data/test_configs/standalone_job/image_stats.yml + xgboost: + - ./data/test_configs/standalone_job/xgb_histogram_examples.yml + - ./data/test_configs/standalone_job/xgb_tree_examples.yml diff --git a/tests/integration_test/xgb_test.py b/tests/integration_test/xgb_test.py new file mode 100644 index 
0000000000..46ae05e7ab --- /dev/null +++ b/tests/integration_test/xgb_test.py @@ -0,0 +1,68 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shlex +import subprocess +import sys + +from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port + + +def run_command_in_subprocess(script_path: str, command_args: str): + new_env = os.environ.copy() + python_path = ":".join(sys.path).lstrip(":")  # strip leading colon only when sys.path starts with "" + new_env["PYTHONPATH"] = f"{python_path}:{os.path.dirname(script_path)}" + process = subprocess.Popen( + shlex.split(f"python3 {script_path} {command_args}"), + env=new_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + return process + + +class TestXGB: + def test_xgb_mock_server_and_client(self): + num_clients = 2 + port = get_open_tcp_port({}) + server_addr = f"localhost:{port}" + integration_test_dir = os.path.dirname(os.path.abspath(__file__)) + server_script_path = os.path.join(integration_test_dir, "src", "mock_xgb", "run_server.py") + + server_process = run_command_in_subprocess( + script_path=server_script_path, command_args=f"-a {server_addr} -c {num_clients} -w 10" + ) + assert server_process is not None + + client_processes = {} + client_script_path = os.path.join(integration_test_dir, "src", "mock_xgb", "run_client.py") + for i in range(num_clients): + client_process = run_command_in_subprocess( + script_path=client_script_path, command_args=f"-a 
{server_addr} -r {i} -n {10}" + ) + client_processes[i] = client_process + assert client_process is not None + + for i in range(num_clients): + stdout, stderr = client_processes[i].communicate() + assert "ERROR" not in stdout.decode("utf-8") + assert stderr.decode("utf-8") == "" + client_processes[i].terminate() + server_process.terminate() + stdout, stderr = server_process.communicate() + assert "ERROR" not in stdout.decode("utf-8") + assert "ERROR" not in stderr.decode("utf-8") + print(f"Server Process Output (stdout):\n{stdout.decode('utf-8')}") + print(f"Server Process Output (stderr):\n{stderr.decode('utf-8')}") diff --git a/tests/unit_test/app_opt/xgboost/__init__.py b/tests/unit_test/app_opt/xgboost/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/__init__.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptor_test.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptor_test.py new file mode 100644 index 0000000000..17507b73dc --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptor_test.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import Mock, patch + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_opt.xgboost.histogram_based_v2.adaptor import XGBAdaptor, XGBClientAdaptor, XGBServerAdaptor +from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant +from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner +from nvflare.app_opt.xgboost.histogram_based_v2.sender import Sender + + +@patch.multiple(XGBAdaptor, __abstractmethods__=set()) +class TestXGBAdaptor: + def test_set_abort_signal(self): + xgb_adaptor = XGBAdaptor() + abort_signal = Signal() + xgb_adaptor.set_abort_signal(abort_signal) + abort_signal.trigger("cool") + assert xgb_adaptor.abort_signal.triggered + + @patch.multiple(XGBRunner, __abstractmethods__=set()) + def test_set_runner(self): + runner = XGBRunner() + xgb_adaptor = XGBAdaptor() + + xgb_adaptor.set_runner(runner) + + assert xgb_adaptor.xgb_runner == runner + + +class TestXGBServerAdaptor: + @patch.multiple(XGBServerAdaptor, __abstractmethods__=set()) + def test_configure(self): + xgb_adaptor = XGBServerAdaptor() + config = {Constant.CONF_KEY_WORLD_SIZE: 66} + ctx = FLContext() + xgb_adaptor.configure(config, ctx) + assert xgb_adaptor.world_size == 66 + + +@patch.multiple(XGBClientAdaptor, __abstractmethods__=set()) +class TestXGBClientAdaptor: + def test_configure(self): + xgb_adaptor = XGBClientAdaptor() + config = {Constant.CONF_KEY_WORLD_SIZE: 66, Constant.CONF_KEY_RANK: 44, Constant.CONF_KEY_NUM_ROUNDS: 100} + ctx = FLContext() + xgb_adaptor.configure(config, ctx) + assert xgb_adaptor.world_size == 66 + assert xgb_adaptor.rank == 44 + assert xgb_adaptor.num_rounds == 100 + + def test_send(self): + xgb_adaptor = XGBClientAdaptor() + sender = Mock(spec=Sender) + reply = Shareable() + reply[Constant.PARAM_KEY_RCV_BUF] = b"hello" + sender.send_to_server.return_value = reply + abort_signal = Signal() + 
xgb_adaptor.set_abort_signal(abort_signal) + xgb_adaptor.set_sender(sender) + assert xgb_adaptor.sender == sender + assert xgb_adaptor._send_request("", Shareable()) == b"hello" diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/__init__.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_client_adaptor_test.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_client_adaptor_test.py new file mode 100644 index 0000000000..96d2ad4b95 --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_client_adaptor_test.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_client_adaptor import GrpcClientAdaptor +from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant + +from .mock_runner import MockRunner, wait_for_status + + +class TestGrpcClientAdaptor: + def test_start_and_stop(self): + runner = MockRunner() + adaptor = GrpcClientAdaptor(in_process=True) + config = {Constant.CONF_KEY_WORLD_SIZE: 66, Constant.CONF_KEY_RANK: 44, Constant.CONF_KEY_NUM_ROUNDS: 100} + ctx = FLContext() + adaptor.configure(config, ctx) + + adaptor.set_runner(runner) + with patch("nvflare.app_opt.xgboost.histogram_based_v2.grpc.grpc_server.GrpcServer.start") as mock_method: + mock_method.return_value = True + adaptor.start(ctx) + assert wait_for_status(runner, True) + + adaptor.stop(ctx) + assert wait_for_status(runner, False) diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_server_adaptor_test.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_server_adaptor_test.py new file mode 100644 index 0000000000..a3f831d6de --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/grpc_server_adaptor_test.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +from nvflare.apis.fl_context import FLContext +from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.grpc_server_adaptor import GrpcServerAdaptor + +from .mock_runner import MockRunner, wait_for_status + + +class TestGrpcServerAdaptor: + def test_start_and_stop(self): + runner = MockRunner() + adaptor = GrpcServerAdaptor(in_process=True) + ctx = FLContext() + + adaptor.set_runner(runner) + with patch("nvflare.app_opt.xgboost.histogram_based_v2.grpc.grpc_client.GrpcClient.start") as mock_method: + mock_method.return_value = True + adaptor.start(ctx) + assert wait_for_status(runner, True) + + adaptor.stop(ctx) + assert wait_for_status(runner, False) diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/mock_runner.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/mock_runner.py new file mode 100644 index 0000000000..9218574e0f --- /dev/null +++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/mock_runner.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import time
+from typing import Tuple
+
+from nvflare.app_opt.xgboost.histogram_based_v2.runner import XGBRunner
+
+TIMEOUT = 5.0
+
+
+def wait_for_status(runner: XGBRunner, expected, timeout=TIMEOUT):
+    start_time = time.time()
+    while runner.started != expected:
+        if time.time() - start_time > timeout:
+            return False
+        time.sleep(1.0)
+    return True
+
+
+class MockRunner(XGBRunner):
+    def __init__(self):
+        self.started = False
+
+    def run(self, ctx: dict):
+        self.started = True
+        while self.started:
+            time.sleep(1.0)
+
+    def stop(self):
+        self.started = False
+
+    def is_stopped(self) -> Tuple[bool, int]:
+        return not self.started, 0