[2.4] Add xgboost example, unit tests, integration tests (#2392)
YuanTingHsieh authored Mar 8, 2024
1 parent b66da2a commit 4946ac6
Showing 42 changed files with 1,388 additions and 96 deletions.
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -26,6 +26,6 @@ sphinx:
python:
install:
- method: pip
-path: .[doc]
+path: .[dev]
# system_packages: true

2 changes: 1 addition & 1 deletion build_doc.sh
@@ -49,7 +49,7 @@ function clean_docs() {
}

function build_html_docs() {
-pip install -e .[doc]
+pip install -e .[dev]
sphinx-apidoc --module-first -f -o docs/apidocs/ nvflare "*poc" "*private"
sphinx-build -b html docs docs/_build
}
@@ -1,23 +1,24 @@
{
"format_version": 2,
"num_rounds": 100,
"executors": [
{
"tasks": [
"train"
],
"executor": {
"id": "Executor",
"name": "FedXGBHistogramExecutor",
"path": "nvflare.app_opt.xgboost.histogram_based.executor.FedXGBHistogramExecutor",
"args": {
"data_loader_id": "dataloader",
"num_rounds": 100,
"num_rounds": "{num_rounds}",
"early_stopping_rounds": 2,
"xgb_params": {
"max_depth": 8,
"eta": 0.1,
"objective": "binary:logistic",
"eval_metric": "auc",
"tree_method": "gpu_hist",
"tree_method": "hist",
"nthread": 16
}
}
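The `"{num_rounds}"` value is a job-config placeholder rather than a literal: it is resolved against the job's top-level `num_rounds` entry, so the round count is defined once and reused wherever it is referenced. A minimal sketch of this style of placeholder resolution, assuming a simple `"{key}"` -> top-level-value rule (NVFlare's actual config resolver is more general than this):

```python
import json

def resolve(node, root):
    """Recursively replace "{key}" strings with the top-level value root[key]."""
    if isinstance(node, dict):
        return {k: resolve(v, root) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve(v, root) for v in node]
    if isinstance(node, str) and node.startswith("{") and node.endswith("}"):
        return root[node[1:-1]]
    return node

config = json.loads('{"num_rounds": 100, "executors": [{"args": {"num_rounds": "{num_rounds}"}}]}')
assert resolve(config, config)["executors"][0]["args"]["num_rounds"] == 100
```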
@@ -1,8 +1,5 @@
{
"format_version": 2,
"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"components": [],
@@ -0,0 +1,38 @@
{
"format_version": 2,
"num_rounds": 100,
"executors": [
{
"tasks": [
"config", "start"
],
"executor": {
"id": "Executor",
"path": "nvflare.app_opt.xgboost.histogram_based_v2.executor.FedXGBHistogramExecutor",
"args": {
"data_loader_id": "dataloader",
"early_stopping_rounds": 2,
"xgb_params": {
"max_depth": 8,
"eta": 0.1,
"objective": "binary:logistic",
"eval_metric": "auc",
"tree_method": "hist",
"nthread": 16
}
}
}
}
],
"task_result_filters": [],
"task_data_filters": [],
"components": [
{
"id": "dataloader",
"path": "higgs_data_loader.HIGGSDataLoader",
"args": {
"data_split_filename": "data_split.json"
}
}
]
}
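In this config, `data_loader_id` points at the `dataloader` entry under `components`: each component is instantiated from its `path` with its `args`, and referencing objects receive it by id. A rough illustration of that id-to-instance resolution (illustrative only; the real lookup happens inside the NVFlare client engine):

```python
import importlib

# Component table as it appears in the job config above.
components = {
    "dataloader": {
        "path": "higgs_data_loader.HIGGSDataLoader",
        "args": {"data_split_filename": "data_split.json"},
    }
}

def build(component_id: str):
    """Instantiate the class named by the component's dotted path with its args."""
    spec = components[component_id]
    module_name, class_name = spec["path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec["args"])

# The executor's data_loader_id="dataloader" would then resolve to:
# dataloader = build("dataloader")  # an instance of HIGGSDataLoader
```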
@@ -0,0 +1,16 @@
{
"format_version": 2,
"num_rounds": 100,
"task_data_filters": [],
"task_result_filters": [],
"components": [],
"workflows": [
{
"id": "xgb_controller",
"path": "nvflare.app_opt.xgboost.histogram_based_v2.controller.XGBFedController",
"args": {
"num_rounds": "{num_rounds}"
}
}
]
}
@@ -0,0 +1,77 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import pandas as pd
import xgboost as xgb

from nvflare.app_opt.xgboost.data_loader import XGBDataLoader


def _read_higgs_with_pandas(data_path, start: int, end: int):
data_size = end - start
data = pd.read_csv(data_path, header=None, skiprows=start, nrows=data_size)
data_num = data.shape[0]

# split to feature and label
x = data.iloc[:, 1:].copy()
y = data.iloc[:, 0].copy()

return x, y, data_num


class HIGGSDataLoader(XGBDataLoader):
def __init__(self, data_split_filename):
"""Reads HIGGS dataset and return XGB data matrix.
Args:
data_split_filename: file name to data splits
"""
self.data_split_filename = data_split_filename

def load_data(self, client_id: str):
with open(self.data_split_filename, "r") as file:
data_split = json.load(file)

data_path = data_split["data_path"]
data_index = data_split["data_index"]

# check that client_id and "valid" are present in the mapping dict
if client_id not in data_index.keys():
raise ValueError(
f"Data does not contain Client {client_id} split",
)

if "valid" not in data_index.keys():
raise ValueError(
"Data does not contain Validation split",
)

site_index = data_index[client_id]
valid_index = data_index["valid"]

# training
x_train, y_train, total_train_data_num = _read_higgs_with_pandas(
data_path=data_path, start=site_index["start"], end=site_index["end"]
)
dmat_train = xgb.DMatrix(x_train, label=y_train)

# validation
x_valid, y_valid, total_valid_data_num = _read_higgs_with_pandas(
data_path=data_path, start=valid_index["start"], end=valid_index["end"]
)
dmat_valid = xgb.DMatrix(x_valid, label=y_valid)

return dmat_train, dmat_valid
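`HIGGSDataLoader` expects `data_split_filename` to name a JSON file with a `data_path` plus a `data_index` that maps each client id, and the shared `"valid"` entry, to a start/end row range. A hypothetical split file for two sites (the path, site names, and row ranges are made-up examples; the real file comes from the example's data-split preparation step):

```python
import json

data_split = {
    "data_path": "/tmp/nvflare/dataset/HIGGS.csv",  # hypothetical CSV location
    "data_index": {
        "valid": {"start": 0, "end": 1000000},        # shared validation rows
        "site-1": {"start": 1000000, "end": 5500000},
        "site-2": {"start": 5500000, "end": 10000000},
    },
}
with open("data_split.json", "w") as f:
    json.dump(data_split, f, indent=4)
```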
10 changes: 10 additions & 0 deletions examples/advanced/xgboost/histogram-based/jobs/base_v2/meta.json
@@ -0,0 +1,10 @@
{
"name": "xgboost_histogram_based_v2",
"resource_spec": {},
"deploy_map": {
"app": [
"@ALL"
]
},
"min_clients": 2
}
2 changes: 2 additions & 0 deletions examples/advanced/xgboost/prepare_job_config.sh
@@ -22,4 +22,6 @@ prepare_job_config 20 cyclic uniform uniform $TREE_METHOD

prepare_job_config 2 histogram uniform uniform $TREE_METHOD
prepare_job_config 5 histogram uniform uniform $TREE_METHOD
+prepare_job_config 2 histogram_v2 uniform uniform $TREE_METHOD
+prepare_job_config 5 histogram_v2 uniform uniform $TREE_METHOD
echo "Job configs generated"
@@ -1,11 +1,6 @@
{
"format_version": 2,

"server": {
"heart_beat_timeout": 600,
"task_request_interval": 0.05
},

"num_rounds": 101,
"task_data_filters": [],
"task_result_filters": [],

@@ -34,7 +29,7 @@
"name": "ScatterAndGather",
"args": {
"min_clients": 5,
"num_rounds": 101,
"num_rounds": "{num_rounds}",
"start_round": 0,
"wait_time_after_min_received": 0,
"aggregator_id": "aggregator",
@@ -13,7 +13,6 @@
"data_loader_id": "dataloader",
"training_mode": "cyclic",
"num_client_bagging": 1,
"lr_mode": "scaled",
"local_model_path": "model.json",
"global_model_path": "model_global.json",
"learning_rate": 0.1,
@@ -1,11 +1,6 @@
{
"format_version": 2,

"server": {
"heart_beat_timeout": 600,
"task_request_interval": 0.05
},

"num_rounds": 20,
"task_data_filters": [],
"task_result_filters": [],

@@ -29,7 +24,7 @@
"id": "cyclic_ctl",
"name": "CyclicController",
"args": {
"num_rounds": 20,
"num_rounds": "{num_rounds}",
"task_assignment_timeout": 60,
"persistor_id": "persistor",
"shareable_generator_id": "shareable_generator",
24 changes: 14 additions & 10 deletions examples/advanced/xgboost/utils/prepare_job_config.py
@@ -20,8 +20,16 @@

from nvflare.apis.fl_constant import JobConstants

+SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__))
+XGB_EXAMPLE_ROOT = SCRIPT_PATH.parent.parent.absolute()
JOB_CONFIGS_ROOT = "jobs"
-MODE_ALGO_MAP = {"bagging": "tree-based", "cyclic": "tree-based", "histogram": "histogram-based"}
+MODE_ALGO_MAP = {
+    "bagging": "tree-based",
+    "cyclic": "tree-based",
+    "histogram": "histogram-based",
+    "histogram_v2": "histogram-based",
+}
+BASE_JOB_MAP = {"bagging": "bagging_base", "cyclic": "cyclic_base", "histogram": "base", "histogram_v2": "base_v2"}


def job_config_args_parser():
@@ -79,12 +87,7 @@ def _get_data_split_name(args, site_name: str) -> str:


def _get_src_job_dir(training_mode):
-base_job_map = {
-    "bagging": "bagging_base",
-    "cyclic": "cyclic_base",
-    "histogram": "base",
-}
-return pathlib.Path(MODE_ALGO_MAP[training_mode]) / JOB_CONFIGS_ROOT / base_job_map[training_mode]
+return XGB_EXAMPLE_ROOT / MODE_ALGO_MAP[training_mode] / JOB_CONFIGS_ROOT / BASE_JOB_MAP[training_mode]
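Anchoring the lookup to `XGB_EXAMPLE_ROOT`, which is derived from the script's own location, makes the result independent of the caller's working directory. A worked example of the new resolution, using a relative stand-in for the absolute root:

```python
import pathlib

# Stand-in for the absolute XGB_EXAMPLE_ROOT (examples/advanced/xgboost).
XGB_EXAMPLE_ROOT = pathlib.Path("examples/advanced/xgboost")
MODE_ALGO_MAP = {"histogram_v2": "histogram-based"}
BASE_JOB_MAP = {"histogram_v2": "base_v2"}

src = XGB_EXAMPLE_ROOT / MODE_ALGO_MAP["histogram_v2"] / "jobs" / BASE_JOB_MAP["histogram_v2"]
print(src)  # examples/advanced/xgboost/histogram-based/jobs/base_v2
```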


def _gen_deploy_map(num_sites: int, site_name_prefix: str) -> dict:
@@ -133,17 +136,18 @@ def _update_client_config(config: dict, args, lr_scale, site_name: str):
num_client_bagging = args.site_num
config["executors"][0]["executor"]["args"]["num_client_bagging"] = num_client_bagging
else:
config["num_rounds"] = args.round_num
config["components"][0]["args"]["data_split_filename"] = data_split_name
config["executors"][0]["executor"]["args"]["xgb_params"]["nthread"] = args.nthread
config["executors"][0]["executor"]["args"]["xgb_params"]["tree_method"] = args.tree_method


def _update_server_config(config: dict, args):
if args.training_mode == "bagging":
config["workflows"][0]["args"]["num_rounds"] = args.round_num + 1
config["num_rounds"] = args.round_num + 1
config["workflows"][0]["args"]["min_clients"] = args.site_num
elif args.training_mode == "cyclic":
config["workflows"][0]["args"]["num_rounds"] = int(args.round_num / args.site_num)
config["num_rounds"] = int(args.round_num / args.site_num)


def _copy_custom_files(src_job_path, src_app_name, dst_job_path, dst_app_name):
@@ -198,7 +202,7 @@ def main():
src_job_path = _get_src_job_dir(args.training_mode)

# create a new job
-dst_job_path = pathlib.Path(MODE_ALGO_MAP[args.training_mode]) / JOB_CONFIGS_ROOT / job_name
+dst_job_path = XGB_EXAMPLE_ROOT / MODE_ALGO_MAP[args.training_mode] / JOB_CONFIGS_ROOT / job_name
if not os.path.exists(dst_job_path):
os.makedirs(dst_job_path)

@@ -13,7 +13,6 @@
# limitations under the License.

import multiprocessing
-import sys
import threading
from typing import Tuple

@@ -59,11 +58,39 @@ def start(self, ctx: dict):


class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer):
"""Implementation of XGBClientAdaptor that uses an internal `GrpcServer`.
The `GrpcClientAdaptor` class serves as an interface between the XGBoost
federated client and federated server components.
It employs its `XGBRunner` to initiate an XGBoost federated gRPC client
and utilizes an internal `GrpcServer` to forward client requests/responses.
The communication flow is as follows:
1. XGBoost federated gRPC client talks to `GrpcClientAdaptor`, which
encapsulates a `GrpcServer`.
Requests are then forwarded to `GrpcServerAdaptor`, which internally
manages a `GrpcClient` responsible for interacting with the XGBoost
federated gRPC server.
2. XGBoost federated gRPC server talks to `GrpcServerAdaptor`, which
encapsulates a `GrpcClient`.
Responses are then forwarded to `GrpcClientAdaptor`, which internally
manages a `GrpcServer` responsible for interacting with the XGBoost
federated gRPC client.
"""

def __init__(
self,
int_server_grpc_options=None,
in_process=False,
):
"""Constructor method to initialize the object.
Args:
int_server_grpc_options: An optional list of key-value pairs (`channel_arguments`
in gRPC Core runtime) to configure the gRPC channel of internal `GrpcServer`.
in_process (bool): Specifies whether to start the `XGBRunner` in the same process or not.
"""
XGBClientAdaptor.__init__(self)
self.int_server_grpc_options = int_server_grpc_options
self.in_process = in_process
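`int_server_grpc_options` is passed through to the internal `GrpcServer` as gRPC channel arguments. A hypothetical construction is sketched below; the module path is assumed from the `histogram_based_v2` package naming used elsewhere in this commit, the option keys are standard gRPC Core channel arguments, and the 1 GiB limits are arbitrary example values:

```python
from nvflare.app_opt.xgboost.histogram_based_v2.grpc_client_adaptor import GrpcClientAdaptor  # assumed path

adaptor = GrpcClientAdaptor(
    int_server_grpc_options=[
        ("grpc.max_send_message_length", 1024 * 1024 * 1024),     # example value
        ("grpc.max_receive_message_length", 1024 * 1024 * 1024),  # example value
    ],
    in_process=True,
)
```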
@@ -212,7 +239,6 @@ def _abort(self, reason: str):

def Allgather(self, request: pb2.AllgatherRequest, context):
try:
self.logger.info(f"Calling Allgather with {sys.getsizeof(request.send_buffer)}")
rcv_buf = self._send_all_gather(
rank=request.rank,
seq=request.sequence_number,
@@ -225,7 +251,6 @@ def Allgather(self, request: pb2.AllgatherRequest, context):

def Allreduce(self, request: pb2.AllreduceRequest, context):
try:
self.logger.info(f"Calling Allreduce with {sys.getsizeof(request.send_buffer)}")
rcv_buf = self._send_all_reduce(
rank=request.rank,
seq=request.sequence_number,
@@ -240,7 +265,6 @@ def Allreduce(self, request: pb2.AllreduceRequest, context):

def Broadcast(self, request: pb2.BroadcastRequest, context):
try:
self.logger.info(f"Calling Broadcast with {sys.getsizeof(request.send_buffer)}")
rcv_buf = self._send_broadcast(
rank=request.rank,
seq=request.sequence_number,
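The deleted log lines reported `sys.getsizeof(request.send_buffer)` on every collective call. Besides being noisy per-operation logging, `sys.getsizeof` measures the Python object (payload plus interpreter overhead), not the buffer length itself, as a small illustration shows (the exact overhead is CPython- and platform-specific):

```python
import sys

buf = b"x" * 1000
print(len(buf))            # 1000 -- the payload length
print(sys.getsizeof(buf))  # ~1033 on 64-bit CPython: payload plus object overhead
```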
