Skip to content

Commit

Permalink
[2.4] Ported newest ReliableMessage to 2.4 (#2815)
Browse files Browse the repository at this point in the history
* Ported 2.5 ReliableMessage to 2.4

* Fixed base_v2 to work with XGBoost 2.11

* Updated copyright year

* Updated doc to mention v2.11

* Corrected version number to 2.1.1
  • Loading branch information
nvidianz authored Aug 21, 2024
1 parent b9d01cb commit 53ecfd5
Show file tree
Hide file tree
Showing 12 changed files with 337 additions and 163 deletions.
2 changes: 2 additions & 0 deletions examples/advanced/xgboost/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ These examples show how to use [NVIDIA FLARE](https://nvflare.readthedocs.io/en/
They use [XGBoost](https://github.com/dmlc/xgboost),
which is an optimized distributed gradient boosting library.

The code was tested with XGBoost V2.1.1. It may not work with other versions of XGBoost.

### HIGGS
The examples illustrate a binary classification task based on [HIGGS dataset](https://archive.ics.uci.edu/dataset/280/higgs).
This dataset contains 11 million instances, each with 28 attributes.
Expand Down
255 changes: 176 additions & 79 deletions nvflare/apis/utils/reliable_message.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion nvflare/app_opt/xgboost/histogram_based/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -
self.log_info(fl_ctx, f"server address is {self._server_address}")

communicator_env = {
"xgboost_communicator": "federated",
"dmlc_communicator": "federated",
"federated_server_address": f"{self._server_address}:{xgb_fl_server_port}",
"federated_world_size": self.world_size,
"federated_rank": self.rank,
Expand Down
12 changes: 2 additions & 10 deletions nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,15 @@ def start_controller(self, fl_ctx: FLContext):
adaptor.initialize(fl_ctx)
self.adaptor = adaptor

engine = fl_ctx.get_engine()
engine.register_aux_message_handler(
topic=Constant.TOPIC_XGB_REQUEST,
message_handle_func=self._process_xgb_request,
)
engine.register_aux_message_handler(
topic=Constant.TOPIC_CLIENT_DONE,
message_handle_func=self._process_client_done,
)

ReliableMessage.register_request_handler(
topic=Constant.TOPIC_XGB_REQUEST,
handler_f=self._process_xgb_request,
fl_ctx=fl_ctx,
)
ReliableMessage.register_request_handler(
topic=Constant.TOPIC_CLIENT_DONE,
handler_f=self._process_client_done,
fl_ctx=fl_ctx,
)

def _trigger_stop(self, fl_ctx: FLContext, error=None):
Expand Down
2 changes: 1 addition & 1 deletion nvflare/app_opt/xgboost/histogram_based_v2/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Constant:
CONF_KEY_NUM_ROUNDS = "num_rounds"

# default component config values
CONFIG_TASK_TIMEOUT = 10
CONFIG_TASK_TIMEOUT = 60
START_TASK_TIMEOUT = 10
XGB_SERVER_READY_TIMEOUT = 5.0

Expand Down
39 changes: 27 additions & 12 deletions nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
/*!
* Copyright 2022 XGBoost contributors
* needs to match file in https://github.com/dmlc/xgboost/blob/v2.0.3/plugin/federated/federated.proto
* Copyright 2022-2023 XGBoost contributors
*/
syntax = "proto3";

package xgboost.federated;
package xgboost.collective.federated;

service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}

enum DataType {
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
HALF = 0;
FLOAT = 1;
DOUBLE = 2;
LONG_DOUBLE = 3;
INT8 = 4;
INT16 = 5;
INT32 = 6;
INT64 = 7;
UINT8 = 8;
UINT16 = 9;
UINT32 = 10;
UINT64 = 11;
}

enum ReduceOperation {
Expand All @@ -43,6 +47,17 @@ message AllgatherReply {
bytes receive_buffer = 1;
}

message AllgatherVRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}

message AllgatherVReply {
bytes receive_buffer = 1;
}

message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
Expand All @@ -67,4 +82,4 @@ message BroadcastRequest {

message BroadcastReply {
bytes receive_buffer = 1;
}
}
45 changes: 25 additions & 20 deletions nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: federated.proto
# Protobuf Python Version: 4.25.0
# Protobuf Python Version: 4.25.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
Expand All @@ -27,29 +28,33 @@



DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x11xgboost.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xbc\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12.\n\tdata_type\x18\x04 \x01(\x0e\x32\x1b.xgboost.federated.DataType\x12<\n\x10reduce_operation\x18\x05 \x01(\x0e\x32\".xgboost.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*d\n\x08\x44\x61taType\x12\x08\n\x04INT8\x10\x00\x12\t\n\x05UINT8\x10\x01\x12\t\n\x05INT32\x10\x02\x12\n\n\x06UINT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06UINT64\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\x90\x02\n\tFederated\x12U\n\tAllgather\x12#.xgboost.federated.AllgatherRequest\x1a!.xgboost.federated.AllgatherReply\"\x00\x12U\n\tAllreduce\x12#.xgboost.federated.AllreduceRequest\x1a!.xgboost.federated.AllreduceReply\"\x00\x12U\n\tBroadcast\x12#.xgboost.federated.BroadcastRequest\x1a!.xgboost.federated.BroadcastReply\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x1cxgboost.collective.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xd2\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x39\n\tdata_type\x18\x04 \x01(\x0e\x32&.xgboost.collective.federated.DataType\x12G\n\x10reduce_operation\x18\x05 \x01(\x0e\x32-.xgboost.collective.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*\x96\x01\n\x08\x44\x61taType\x12\x08\n\x04HALF\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06\x44OUBLE\x10\x02\x12\x0f\n\x0bLONG_DOUBLE\x10\x03\x12\x08\n\x04INT8\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\t\n\x05UINT8\x10\x08\x12\n\n\x06UINT16\x10\t\x12\n\n\x06UINT32\x10\n\x12\n\n\x06UINT64\x10\x0b*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xc2\x03\n\tFederated\x12k\n\tAllgather\x12..xgboost.collective.federated.AllgatherRequest\x1a,.xgboost.collective.federated.AllgatherReply\"\x00\x12n\n\nAllgatherV\x12/.xgboost.collective.federated.AllgatherVRequest\x1a-.xgboost.collective.federated.AllgatherVReply\"\x00\x12k\n\tAllreduce\x12..xgboost.collective.federated.AllreduceRequest\x1a,.xgboost.collective.federated.AllreduceReply\"\x00\x12k\n\tBroadcast\x12..xgboost.collective.federated.BroadcastRequest\x1a,.xgboost.collective.federated.BroadcastReply\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals['_DATATYPE']._serialized_start=529
_globals['_DATATYPE']._serialized_end=629
_globals['_REDUCEOPERATION']._serialized_start=631
_globals['_REDUCEOPERATION']._serialized_end=725
_globals['_ALLGATHERREQUEST']._serialized_start=38
_globals['_ALLGATHERREQUEST']._serialized_end=116
_globals['_ALLGATHERREPLY']._serialized_start=118
_globals['_ALLGATHERREPLY']._serialized_end=158
_globals['_ALLREDUCEREQUEST']._serialized_start=161
_globals['_ALLREDUCEREQUEST']._serialized_end=349
_globals['_ALLREDUCEREPLY']._serialized_start=351
_globals['_ALLREDUCEREPLY']._serialized_end=391
_globals['_BROADCASTREQUEST']._serialized_start=393
_globals['_BROADCASTREQUEST']._serialized_end=485
_globals['_BROADCASTREPLY']._serialized_start=487
_globals['_BROADCASTREPLY']._serialized_end=527
_globals['_FEDERATED']._serialized_start=728
_globals['_FEDERATED']._serialized_end=1000
_globals['_DATATYPE']._serialized_start=687
_globals['_DATATYPE']._serialized_end=837
_globals['_REDUCEOPERATION']._serialized_start=839
_globals['_REDUCEOPERATION']._serialized_end=933
_globals['_ALLGATHERREQUEST']._serialized_start=49
_globals['_ALLGATHERREQUEST']._serialized_end=127
_globals['_ALLGATHERREPLY']._serialized_start=129
_globals['_ALLGATHERREPLY']._serialized_end=169
_globals['_ALLGATHERVREQUEST']._serialized_start=171
_globals['_ALLGATHERVREQUEST']._serialized_end=250
_globals['_ALLGATHERVREPLY']._serialized_start=252
_globals['_ALLGATHERVREPLY']._serialized_end=293
_globals['_ALLREDUCEREQUEST']._serialized_start=296
_globals['_ALLREDUCEREQUEST']._serialized_end=506
_globals['_ALLREDUCEREPLY']._serialized_start=508
_globals['_ALLREDUCEREPLY']._serialized_end=548
_globals['_BROADCASTREQUEST']._serialized_start=550
_globals['_BROADCASTREQUEST']._serialized_end=642
_globals['_BROADCASTREPLY']._serialized_start=644
_globals['_BROADCASTREPLY']._serialized_end=684
_globals['_FEDERATED']._serialized_start=936
_globals['_FEDERATED']._serialized_end=1386
# @@protoc_insertion_point(module_scope)
40 changes: 32 additions & 8 deletions nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ DESCRIPTOR: _descriptor.FileDescriptor

class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
HALF: _ClassVar[DataType]
FLOAT: _ClassVar[DataType]
DOUBLE: _ClassVar[DataType]
LONG_DOUBLE: _ClassVar[DataType]
INT8: _ClassVar[DataType]
UINT8: _ClassVar[DataType]
INT16: _ClassVar[DataType]
INT32: _ClassVar[DataType]
UINT32: _ClassVar[DataType]
INT64: _ClassVar[DataType]
UINT8: _ClassVar[DataType]
UINT16: _ClassVar[DataType]
UINT32: _ClassVar[DataType]
UINT64: _ClassVar[DataType]
FLOAT: _ClassVar[DataType]
DOUBLE: _ClassVar[DataType]

class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
Expand All @@ -24,14 +28,18 @@ class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
BITWISE_AND: _ClassVar[ReduceOperation]
BITWISE_OR: _ClassVar[ReduceOperation]
BITWISE_XOR: _ClassVar[ReduceOperation]
HALF: DataType
FLOAT: DataType
DOUBLE: DataType
LONG_DOUBLE: DataType
INT8: DataType
UINT8: DataType
INT16: DataType
INT32: DataType
UINT32: DataType
INT64: DataType
UINT8: DataType
UINT16: DataType
UINT32: DataType
UINT64: DataType
FLOAT: DataType
DOUBLE: DataType
MAX: ReduceOperation
MIN: ReduceOperation
SUM: ReduceOperation
Expand All @@ -55,6 +63,22 @@ class AllgatherReply(_message.Message):
receive_buffer: bytes
def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ...

class AllgatherVRequest(_message.Message):
__slots__ = ("sequence_number", "rank", "send_buffer")
SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int]
RANK_FIELD_NUMBER: _ClassVar[int]
SEND_BUFFER_FIELD_NUMBER: _ClassVar[int]
sequence_number: int
rank: int
send_buffer: bytes
def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ...

class AllgatherVReply(_message.Message):
__slots__ = ("receive_buffer",)
RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int]
receive_buffer: bytes
def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ...

class AllreduceRequest(_message.Message):
__slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation")
SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int]
Expand Down
Loading

0 comments on commit 53ecfd5

Please sign in to comment.