diff --git a/examples/advanced/xgboost/README.md b/examples/advanced/xgboost/README.md index 9f2e564779..e6e86d96b1 100644 --- a/examples/advanced/xgboost/README.md +++ b/examples/advanced/xgboost/README.md @@ -11,6 +11,8 @@ These examples show how to use [NVIDIA FLARE](https://nvflare.readthedocs.io/en/ They use [XGBoost](https://github.com/dmlc/xgboost), which is an optimized distributed gradient boosting library. +The code was tested with XGBoost V2.1.1. It may not work with other versions of XGBoost. + ### HIGGS The examples illustrate a binary classification task based on [HIGGS dataset](https://archive.ics.uci.edu/dataset/280/higgs). This dataset contains 11 million instances, each with 28 attributes. diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index bfa76e3dac..2551a28cb9 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -51,6 +51,8 @@ TOPIC_RELIABLE_REPLY = "RM.RELIABLE_REPLY" PROP_KEY_TX_ID = "RM.TX_ID" +PROP_KEY_TOPIC = "RM.TOPIC" +PROP_KEY_OP = "RM.OP" def _extract_result(reply: dict, target: str): @@ -94,61 +96,73 @@ def __init__(self, topic, request_handler_f, executor, per_msg_timeout, tx_timeo self.tx_id = None self.reply_time = None self.replying = False + self.lock = threading.Lock() def process(self, request: Shareable, fl_ctx: FLContext) -> Shareable: - self.tx_id = request.get_header(HEADER_TX_ID) - op = request.get_header(HEADER_OP) - peer_ctx = fl_ctx.get_peer_context() - assert isinstance(peer_ctx, FLContext) - self.source = peer_ctx.get_identity_name() - if op == OP_REQUEST: - # it is possible that a new request for the same tx is received while we are processing the previous one - if not self.rcv_time: - self.rcv_time = time.time() - self.per_msg_timeout = request.get_header(HEADER_PER_MSG_TIMEOUT) - self.tx_timeout = request.get_header(HEADER_TX_TIMEOUT) - - # start processing - ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}") - self.executor.submit(self._do_request, request, fl_ctx) - return _status_reply(STATUS_IN_PROCESS) # ack - elif self.result: - # we already finished processing - send the result back - ReliableMessage.info(fl_ctx, "resend result back to requester") - return self.result - else: - # we are still processing - ReliableMessage.info(fl_ctx, "got request - the request is being processed") - return _status_reply(STATUS_IN_PROCESS) - elif op == OP_QUERY: - if self.result: - if self.reply_time: - # result already sent back successfully - ReliableMessage.info(fl_ctx, "got query: we already replied successfully") - return _status_reply(STATUS_REPLIED) - elif self.replying: - # result is being sent - ReliableMessage.info(fl_ctx, "got query: reply is being sent") - return _status_reply(STATUS_IN_REPLY) - else: - # try to send the result again - ReliableMessage.info(fl_ctx, "got query: sending reply again") + if not ReliableMessage.is_available(): + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + with self.lock: + self.tx_id = request.get_header(HEADER_TX_ID) + op = request.get_header(HEADER_OP) + peer_ctx = fl_ctx.get_peer_context() + assert isinstance(peer_ctx, FLContext) + self.source = peer_ctx.get_identity_name() + if op == OP_REQUEST: + # it is possible that a new request for the same tx is received while we are processing the previous one + if not self.rcv_time: + self.rcv_time = time.time() + self.per_msg_timeout = request.get_header(HEADER_PER_MSG_TIMEOUT) + self.tx_timeout = 
request.get_header(HEADER_TX_TIMEOUT) + + # start processing + ReliableMessage.info(fl_ctx, f"started processing request of topic {self.topic}") + try: + self.executor.submit(self._do_request, request, fl_ctx) + return _status_reply(STATUS_IN_PROCESS) # ack + except Exception as ex: + # it is possible that the RM is already closed (self.executor is shut down) + ReliableMessage.error(fl_ctx, f"failed to submit request: {secure_format_exception(ex)}") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + elif self.result: + # we already finished processing - send the result back + ReliableMessage.info(fl_ctx, "resend result back to requester") return self.result - else: - # still in process - if time.time() - self.rcv_time > self.tx_timeout: - # the process is taking too much time - ReliableMessage.error(fl_ctx, f"aborting processing since exceeded max tx time {self.tx_timeout}") - return _status_reply(STATUS_ABORTED) else: - ReliableMessage.info(fl_ctx, "got query: request is in-process") + # we are still processing + ReliableMessage.info(fl_ctx, "got request - the request is being processed") return _status_reply(STATUS_IN_PROCESS) + elif op == OP_QUERY: + if self.result: + if self.reply_time: + # result already sent back successfully + ReliableMessage.info(fl_ctx, "got query: we already replied successfully") + return _status_reply(STATUS_REPLIED) + elif self.replying: + # result is being sent + ReliableMessage.info(fl_ctx, "got query: reply is being sent") + return _status_reply(STATUS_IN_REPLY) + else: + # try to send the result again + ReliableMessage.info(fl_ctx, "got query: sending reply again") + return self.result + else: + # still in process + if time.time() - self.rcv_time > self.tx_timeout: + # the process is taking too much time + ReliableMessage.error( + fl_ctx, f"aborting processing since exceeded max tx time {self.tx_timeout}" + ) + return _status_reply(STATUS_ABORTED) + else: + ReliableMessage.debug(fl_ctx, "got query: request is in-process") + return _status_reply(STATUS_IN_PROCESS) def _try_reply(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() self.replying = True start_time = time.time() - ReliableMessage.info(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}") + ReliableMessage.debug(fl_ctx, f"try to send reply back to {self.source}: {self.per_msg_timeout=}") ack = engine.send_aux_request( targets=[self.source], topic=TOPIC_RELIABLE_REPLY, @@ -162,15 +176,20 @@ def _try_reply(self, fl_ctx: FLContext): if rc == ReturnCode.OK: # reply sent successfully! self.reply_time = time.time() - ReliableMessage.info(fl_ctx, f"sent reply successfully in {time_spent} secs") + ReliableMessage.debug(fl_ctx, f"sent reply successfully in {time_spent} secs") + + # release the receiver kept by the ReliableMessage! 
+            ReliableMessage.release_request_receiver(self, fl_ctx)
         else:
+            # unsure whether the reply was sent successfully
+            # do not release the request receiver in case the requester asks for the result in a query
             ReliableMessage.error(
                 fl_ctx, f"failed to send reply in {time_spent} secs: {rc=}; will wait for requester to query"
             )
 
     def _do_request(self, request: Shareable, fl_ctx: FLContext):
         start_time = time.time()
-        ReliableMessage.info(fl_ctx, "invoking request handler")
+        ReliableMessage.debug(fl_ctx, "invoking request handler")
         try:
             result = self.request_handler_f(self.topic, request, fl_ctx)
         except Exception as e:
@@ -182,11 +201,13 @@ def _do_request(self, request: Shareable, fl_ctx: FLContext):
         result.set_header(HEADER_OP, OP_REPLY)
         result.set_header(HEADER_TOPIC, self.topic)
         self.result = result
-        ReliableMessage.info(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
+        ReliableMessage.debug(fl_ctx, f"finished request handler in {time.time()-start_time} secs")
         self._try_reply(fl_ctx)
 
 
 class _ReplyReceiver:
+    """This class handles reliable message replies on the sending end."""
+
     def __init__(self, tx_id: str, per_msg_timeout: float, tx_timeout: float):
         self.tx_id = tx_id
         self.tx_start_time = time.time()
@@ -215,13 +236,14 @@ class ReliableMessage:
     _logger = logging.getLogger("ReliableMessage")
 
     @classmethod
-    def register_request_handler(cls, topic: str, handler_f):
+    def register_request_handler(cls, topic: str, handler_f, fl_ctx: FLContext):
         """Register a handler for the reliable message with this topic
 
         Args:
             topic: The topic of the reliable message
             handler_f: The callback function to handle the request in the form of handler_f(topic, request, fl_ctx)
+            fl_ctx: FL Context
 
         """
         if not cls._enabled:
             raise RuntimeError("ReliableMessage is not enabled. Please call ReliableMessage.enable() to enable it")
@@ -229,6 +251,13 @@ def register_request_handler(cls, topic: str, handler_f):
             raise TypeError(f"handler_f must be callable but {type(handler_f)}")
         cls._topic_to_handle[topic] = handler_f
 
+        # ReliableMessage also sends aux messages directly if tx_timeout is too small
+        engine = fl_ctx.get_engine()
+        engine.register_aux_message_handler(
+            topic=topic,
+            message_handle_func=handler_f,
+        )
+
     @classmethod
     def _get_or_create_receiver(cls, topic: str, request: Shareable, handler_f) -> _RequestReceiver:
         tx_id = request.get_header(HEADER_TX_ID)
@@ -250,42 +279,65 @@ def _
     @classmethod
     def _receive_request(cls, topic: str, request: Shareable, fl_ctx: FLContext):
         tx_id = request.get_header(HEADER_TX_ID)
-        fl_ctx.set_prop(key=PROP_KEY_TX_ID, value=tx_id, sticky=False, private=True)
         op = request.get_header(HEADER_OP)
-        topic = request.get_header(HEADER_TOPIC)
+        rm_topic = request.get_header(HEADER_TOPIC)
+        fl_ctx.set_prop(key=PROP_KEY_TX_ID, value=tx_id, sticky=False, private=True)
+        fl_ctx.set_prop(key=PROP_KEY_OP, value=op, sticky=False, private=True)
+        fl_ctx.set_prop(key=PROP_KEY_TOPIC, value=rm_topic, sticky=False, private=True)
+        cls.debug(fl_ctx, f"received aux msg ({topic=}) for RM request")
+
         if op == OP_REQUEST:
-            handler_f = cls._topic_to_handle.get(topic)
+            handler_f = cls._topic_to_handle.get(rm_topic)
             if not handler_f:
                 # no handler registered for this topic!
- cls.error(fl_ctx, f"no handler registered for request {topic=}") + cls.error(fl_ctx, f"no handler registered for request {rm_topic=}") return make_reply(ReturnCode.TOPIC_UNKNOWN) - receiver = cls._get_or_create_receiver(topic, request, handler_f) - cls.info(fl_ctx, f"received request {topic=}") + receiver = cls._get_or_create_receiver(rm_topic, request, handler_f) + cls.debug(fl_ctx, f"received request {rm_topic=}") return receiver.process(request, fl_ctx) elif op == OP_QUERY: receiver = cls._req_receivers.get(tx_id) if not receiver: - cls.error(fl_ctx, f"received query but the request ({topic=}) is not received!") + cls.warning( + fl_ctx, f"received query but the request ({rm_topic=} {tx_id=}) is not received or already done!" + ) return _status_reply(STATUS_NOT_RECEIVED) # meaning the request wasn't received else: return receiver.process(request, fl_ctx) else: - cls.error(fl_ctx, f"received invalid op {op} for the request ({topic=})") + cls.error(fl_ctx, f"received invalid op {op} for the request ({rm_topic=})") return make_reply(rc=ReturnCode.BAD_REQUEST_DATA) @classmethod def _receive_reply(cls, topic: str, request: Shareable, fl_ctx: FLContext): tx_id = request.get_header(HEADER_TX_ID) fl_ctx.set_prop(key=PROP_KEY_TX_ID, value=tx_id, private=True, sticky=False) + cls.debug(fl_ctx, f"received aux msg ({topic=}) for RM reply") receiver = cls._reply_receivers.get(tx_id) if not receiver: cls.error(fl_ctx, "received reply but we are no longer waiting for it") else: assert isinstance(receiver, _ReplyReceiver) - cls.info(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter") + cls.debug(fl_ctx, f"received reply in {time.time()-receiver.tx_start_time} secs - set waiter") receiver.process(request) return make_reply(ReturnCode.OK) + @classmethod + def release_request_receiver(cls, receiver: _RequestReceiver, fl_ctx: FLContext): + """Release the specified _RequestReceiver from the receiver table. + This is to be called after the received request is finished. + + Args: + receiver: the _RequestReceiver to be released + fl_ctx: the FL Context + + Returns: None + + """ + with cls._tx_lock: + cls._req_receivers.pop(receiver.tx_id, None) + cls.debug(fl_ctx, f"released request receiver of TX {receiver.tx_id}") + @classmethod def enable(cls, fl_ctx: FLContext): """Enable ReliableMessage. This method can be called multiple times, but only the 1st call has effect. 
@@ -330,7 +382,7 @@ def _monitor_req_receivers(cls):
                 now = time.time()
                 for tx_id, receiver in cls._req_receivers.items():
                     assert isinstance(receiver, _RequestReceiver)
-                    if receiver.rcv_time and now - receiver.rcv_time > 4 * receiver.tx_timeout:
+                    if receiver.rcv_time and now - receiver.rcv_time > receiver.tx_timeout:
                         cls._logger.info(f"detected expired request receiver {tx_id}")
                         expired_receivers.append(tx_id)
 
@@ -356,19 +408,54 @@ def shutdown(cls):
 
     @classmethod
     def _log_msg(cls, fl_ctx: FLContext, msg: str):
+        props = []
         tx_id = fl_ctx.get_prop(PROP_KEY_TX_ID)
         if tx_id:
-            msg = f"[RM: {tx_id=}] {msg}"
+            props.append(f"rm_tx={tx_id}")
+
+        op = fl_ctx.get_prop(PROP_KEY_OP)
+        if op:
+            props.append(f"rm_op={op}")
+
+        topic = fl_ctx.get_prop(PROP_KEY_TOPIC)
+        if topic:
+            props.append(f"rm_topic={topic}")
+
+        rm_ctx = ""
+        if props:
+            rm_ctx = " ".join(props)
+
+        if rm_ctx:
+            msg = f"[{rm_ctx}] {msg}"
         return generate_log_message(fl_ctx, msg)
 
     @classmethod
     def info(cls, fl_ctx: FLContext, msg: str):
        cls._logger.info(cls._log_msg(fl_ctx, msg))
 
+    @classmethod
+    def warning(cls, fl_ctx: FLContext, msg: str):
+        cls._logger.warning(cls._log_msg(fl_ctx, msg))
+
     @classmethod
     def error(cls, fl_ctx: FLContext, msg: str):
         cls._logger.error(cls._log_msg(fl_ctx, msg))
 
+    @classmethod
+    def is_available(cls):
+        """Return whether the ReliableMessage service is available.
+
+        Returns:
+            True if the service is enabled and has not been shut down; False otherwise.
+        """
+        if cls._shutdown_asked:
+            return False
+
+        if not cls._enabled:
+            return False
+
+        return True
+
     @classmethod
     def debug(cls, fl_ctx: FLContext, msg: str):
         cls._logger.debug(cls._log_msg(fl_ctx, msg))
@@ -384,18 +471,23 @@ def send_request(
         abort_signal: Signal,
         fl_ctx: FLContext,
     ) -> Shareable:
-        """Send a reliable request.
+        """Send a request reliably.
 
         Args:
-            target: the target cell of this request
-            topic: topic of the request;
-            request: the request to be sent
-            per_msg_timeout: timeout when sending a message
-            tx_timeout: the timeout of the whole transaction
-            abort_signal: abort signal
-            fl_ctx: the FL context
+            target: The target cell of this request.
+            topic: The topic of the request.
+            request: The request to be sent.
+            per_msg_timeout (float): Number of seconds to wait for each message before timing out.
+            tx_timeout (float): Timeout for the entire transaction.
+            abort_signal (Signal): Signal to abort the request.
+            fl_ctx (FLContext): Context for federated learning.
 
-        Returns: reply from the peer.
+        Returns:
+            The reply from the peer.
+
+        Note:
+            If `tx_timeout` is not specified or is less than or equal to `per_msg_timeout`,
+            the request will be sent only once without retrying.
 
         """
         check_positive_number("per_msg_timeout", per_msg_timeout)
@@ -459,7 +551,7 @@ def _send_request(
             return make_reply(ReturnCode.COMMUNICATION_ERROR)
 
         if num_tries > 0:
-            cls.info(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")
+            cls.debug(fl_ctx, f"retry #{num_tries} sending request: {per_msg_timeout=}")
 
         ack = engine.send_aux_request(
             targets=[target],
@@ -476,7 +568,7 @@ def _send_request(
                 # the reply is already the result - we are done!
                 # this could happen when we didn't get positive ack for our first request, and the result was
                 # already produced when we did the 2nd request (this request).
- cls.info(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}") + cls.debug(fl_ctx, f"C1: received result in {time.time()-receiver.tx_start_time} seconds; {rc=}") return ack # the ack is a status report - check status @@ -484,7 +576,7 @@ def _send_request( if status and status != STATUS_NOT_RECEIVED: # status should never be STATUS_NOT_RECEIVED, unless there is a bug in the receiving logic # STATUS_NOT_RECEIVED is only possible during "query" phase. - cls.info(fl_ctx, f"received status ack: {rc=} {status=}") + cls.debug(fl_ctx, f"received status ack: {rc=} {status=}") break if time.time() + cls._query_interval - receiver.tx_start_time >= tx_timeout: @@ -492,7 +584,7 @@ def _send_request( return make_reply(ReturnCode.COMMUNICATION_ERROR) # we didn't get a positive ack - wait a short time and re-send the request. - cls.info(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs") + cls.debug(fl_ctx, f"unsure the request was received ({rc=}): will retry in {cls._query_interval} secs") num_tries += 1 start = time.time() while time.time() - start < cls._query_interval: @@ -501,7 +593,7 @@ def _send_request( return make_reply(ReturnCode.TASK_ABORTED) time.sleep(0.1) - cls.info(fl_ctx, "request was received by the peer - will query for result") + cls.debug(fl_ctx, "request was received by the peer - will query for result") return cls._query_result(target, abort_signal, fl_ctx, receiver) @classmethod @@ -533,7 +625,7 @@ def _query_result( # we already received result sent by the target. # Note that we don't wait forever here - we only wait for _query_interval, so we could # check other condition and/or send query to ask for result. - cls.info(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds") + cls.debug(fl_ctx, f"C2: received result in {time.time()-receiver.tx_start_time} seconds") return receiver.result if abort_signal and abort_signal.triggered: @@ -547,7 +639,7 @@ def _query_result( # send a query. The ack of the query could be the result itself, or a status report. # Note: the ack could be the result because we failed to receive the result sent by the target earlier. num_tries += 1 - cls.info(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}") + cls.debug(fl_ctx, f"query #{num_tries}: try to get result from {target}: {per_msg_timeout=}") ack = engine.send_aux_request( targets=[target], topic=TOPIC_RELIABLE_REQUEST, @@ -555,13 +647,18 @@ def _query_result( timeout=per_msg_timeout, fl_ctx=fl_ctx, ) + + # Ignore query result if reply result is already received + if receiver.result_ready.is_set(): + return receiver.result + last_query_time = time.time() ack, rc = _extract_result(ack, target) if ack and rc not in [ReturnCode.COMMUNICATION_ERROR]: op = ack.get_header(HEADER_OP) if op == OP_REPLY: # the ack is result itself! 
- cls.info(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds") + cls.debug(fl_ctx, f"C3: received result in {time.time()-receiver.tx_start_time} seconds") return ack status = ack.get_header(HEADER_STATUS) @@ -573,6 +670,6 @@ def _query_result( cls.error(fl_ctx, f"peer {target} aborted processing!") return _error_reply(ReturnCode.EXECUTION_EXCEPTION, "Aborted") - cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}") + cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=} {status=} {op=}") else: - cls.info(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}") + cls.debug(fl_ctx, f"will retry query in {cls._query_interval} secs: {rc=}") diff --git a/nvflare/app_opt/xgboost/histogram_based/executor.py b/nvflare/app_opt/xgboost/histogram_based/executor.py index 9a2f67f3bc..3ae15eb48f 100644 --- a/nvflare/app_opt/xgboost/histogram_based/executor.py +++ b/nvflare/app_opt/xgboost/histogram_based/executor.py @@ -260,7 +260,7 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) - self.log_info(fl_ctx, f"server address is {self._server_address}") communicator_env = { - "xgboost_communicator": "federated", + "dmlc_communicator": "federated", "federated_server_address": f"{self._server_address}:{xgb_fl_server_port}", "federated_world_size": self.world_size, "federated_rank": self.rank, diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py index a91ff5adf5..2048cd2570 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptor_controller.py @@ -209,23 +209,15 @@ def start_controller(self, fl_ctx: FLContext): adaptor.initialize(fl_ctx) self.adaptor = adaptor - engine = fl_ctx.get_engine() - engine.register_aux_message_handler( - topic=Constant.TOPIC_XGB_REQUEST, - message_handle_func=self._process_xgb_request, - ) - engine.register_aux_message_handler( - topic=Constant.TOPIC_CLIENT_DONE, - message_handle_func=self._process_client_done, - ) - ReliableMessage.register_request_handler( topic=Constant.TOPIC_XGB_REQUEST, handler_f=self._process_xgb_request, + fl_ctx=fl_ctx, ) ReliableMessage.register_request_handler( topic=Constant.TOPIC_CLIENT_DONE, handler_f=self._process_client_done, + fl_ctx=fl_ctx, ) def _trigger_stop(self, fl_ctx: FLContext, error=None): diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py index 730d19bbb8..b4750c2461 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py @@ -27,7 +27,7 @@ class Constant: CONF_KEY_NUM_ROUNDS = "num_rounds" # default component config values - CONFIG_TASK_TIMEOUT = 10 + CONFIG_TASK_TIMEOUT = 60 START_TASK_TIMEOUT = 10 XGB_SERVER_READY_TIMEOUT = 5.0 diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto index dc884f3d29..fbc2adf503 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated.proto @@ -1,26 +1,30 @@ /*! 
- * Copyright 2022 XGBoost contributors - * needs to match file in https://github.com/dmlc/xgboost/blob/v2.0.3/plugin/federated/federated.proto + * Copyright 2022-2023 XGBoost contributors */ syntax = "proto3"; -package xgboost.federated; +package xgboost.collective.federated; service Federated { rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} + rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {} rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} } enum DataType { - INT8 = 0; - UINT8 = 1; - INT32 = 2; - UINT32 = 3; - INT64 = 4; - UINT64 = 5; - FLOAT = 6; - DOUBLE = 7; + HALF = 0; + FLOAT = 1; + DOUBLE = 2; + LONG_DOUBLE = 3; + INT8 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + UINT8 = 8; + UINT16 = 9; + UINT32 = 10; + UINT64 = 11; } enum ReduceOperation { @@ -43,6 +47,17 @@ message AllgatherReply { bytes receive_buffer = 1; } +message AllgatherVRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherVReply { + bytes receive_buffer = 1; +} + message AllreduceRequest { // An incrementing counter that is unique to each round to operations. uint64 sequence_number = 1; @@ -67,4 +82,4 @@ message BroadcastRequest { message BroadcastReply { bytes receive_buffer = 1; -} \ No newline at end of file +} diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py index fc0a379471..e69d5d5e07 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: federated.proto -# Protobuf Python Version: 4.25.0 +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -27,29 +28,33 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x11xgboost.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xbc\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12.\n\tdata_type\x18\x04 \x01(\x0e\x32\x1b.xgboost.federated.DataType\x12<\n\x10reduce_operation\x18\x05 \x01(\x0e\x32\".xgboost.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*d\n\x08\x44\x61taType\x12\x08\n\x04INT8\x10\x00\x12\t\n\x05UINT8\x10\x01\x12\t\n\x05INT32\x10\x02\x12\n\n\x06UINT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06UINT64\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\x90\x02\n\tFederated\x12U\n\tAllgather\x12#.xgboost.federated.AllgatherRequest\x1a!.xgboost.federated.AllgatherReply\"\x00\x12U\n\tAllreduce\x12#.xgboost.federated.AllreduceRequest\x1a!.xgboost.federated.AllreduceReply\"\x00\x12U\n\tBroadcast\x12#.xgboost.federated.BroadcastRequest\x1a!.xgboost.federated.BroadcastReply\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x1cxgboost.collective.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xd2\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x39\n\tdata_type\x18\x04 \x01(\x0e\x32&.xgboost.collective.federated.DataType\x12G\n\x10reduce_operation\x18\x05 \x01(\x0e\x32-.xgboost.collective.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 
\x01(\x0c*\x96\x01\n\x08\x44\x61taType\x12\x08\n\x04HALF\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06\x44OUBLE\x10\x02\x12\x0f\n\x0bLONG_DOUBLE\x10\x03\x12\x08\n\x04INT8\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\t\n\x05UINT8\x10\x08\x12\n\n\x06UINT16\x10\t\x12\n\n\x06UINT32\x10\n\x12\n\n\x06UINT64\x10\x0b*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xc2\x03\n\tFederated\x12k\n\tAllgather\x12..xgboost.collective.federated.AllgatherRequest\x1a,.xgboost.collective.federated.AllgatherReply\"\x00\x12n\n\nAllgatherV\x12/.xgboost.collective.federated.AllgatherVRequest\x1a-.xgboost.collective.federated.AllgatherVReply\"\x00\x12k\n\tAllreduce\x12..xgboost.collective.federated.AllreduceRequest\x1a,.xgboost.collective.federated.AllreduceReply\"\x00\x12k\n\tBroadcast\x12..xgboost.collective.federated.BroadcastRequest\x1a,.xgboost.collective.federated.BroadcastReply\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals['_DATATYPE']._serialized_start=529 - _globals['_DATATYPE']._serialized_end=629 - _globals['_REDUCEOPERATION']._serialized_start=631 - _globals['_REDUCEOPERATION']._serialized_end=725 - _globals['_ALLGATHERREQUEST']._serialized_start=38 - _globals['_ALLGATHERREQUEST']._serialized_end=116 - _globals['_ALLGATHERREPLY']._serialized_start=118 - _globals['_ALLGATHERREPLY']._serialized_end=158 - _globals['_ALLREDUCEREQUEST']._serialized_start=161 - _globals['_ALLREDUCEREQUEST']._serialized_end=349 - _globals['_ALLREDUCEREPLY']._serialized_start=351 - _globals['_ALLREDUCEREPLY']._serialized_end=391 - _globals['_BROADCASTREQUEST']._serialized_start=393 - _globals['_BROADCASTREQUEST']._serialized_end=485 - _globals['_BROADCASTREPLY']._serialized_start=487 - _globals['_BROADCASTREPLY']._serialized_end=527 - _globals['_FEDERATED']._serialized_start=728 - _globals['_FEDERATED']._serialized_end=1000 + _globals['_DATATYPE']._serialized_start=687 + _globals['_DATATYPE']._serialized_end=837 + _globals['_REDUCEOPERATION']._serialized_start=839 + _globals['_REDUCEOPERATION']._serialized_end=933 + _globals['_ALLGATHERREQUEST']._serialized_start=49 + _globals['_ALLGATHERREQUEST']._serialized_end=127 + _globals['_ALLGATHERREPLY']._serialized_start=129 + _globals['_ALLGATHERREPLY']._serialized_end=169 + _globals['_ALLGATHERVREQUEST']._serialized_start=171 + _globals['_ALLGATHERVREQUEST']._serialized_end=250 + _globals['_ALLGATHERVREPLY']._serialized_start=252 + _globals['_ALLGATHERVREPLY']._serialized_end=293 + _globals['_ALLREDUCEREQUEST']._serialized_start=296 + _globals['_ALLREDUCEREQUEST']._serialized_end=506 + _globals['_ALLREDUCEREPLY']._serialized_start=508 + _globals['_ALLREDUCEREPLY']._serialized_end=548 + _globals['_BROADCASTREQUEST']._serialized_start=550 + _globals['_BROADCASTREQUEST']._serialized_end=642 + _globals['_BROADCASTREPLY']._serialized_start=644 + _globals['_BROADCASTREPLY']._serialized_end=684 + _globals['_FEDERATED']._serialized_start=936 + _globals['_FEDERATED']._serialized_end=1386 # @@protoc_insertion_point(module_scope) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi 
index e82022b080..7dc3e6dde1 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2.pyi @@ -7,14 +7,18 @@ DESCRIPTOR: _descriptor.FileDescriptor class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () + HALF: _ClassVar[DataType] + FLOAT: _ClassVar[DataType] + DOUBLE: _ClassVar[DataType] + LONG_DOUBLE: _ClassVar[DataType] INT8: _ClassVar[DataType] - UINT8: _ClassVar[DataType] + INT16: _ClassVar[DataType] INT32: _ClassVar[DataType] - UINT32: _ClassVar[DataType] INT64: _ClassVar[DataType] + UINT8: _ClassVar[DataType] + UINT16: _ClassVar[DataType] + UINT32: _ClassVar[DataType] UINT64: _ClassVar[DataType] - FLOAT: _ClassVar[DataType] - DOUBLE: _ClassVar[DataType] class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () @@ -24,14 +28,18 @@ class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): BITWISE_AND: _ClassVar[ReduceOperation] BITWISE_OR: _ClassVar[ReduceOperation] BITWISE_XOR: _ClassVar[ReduceOperation] +HALF: DataType +FLOAT: DataType +DOUBLE: DataType +LONG_DOUBLE: DataType INT8: DataType -UINT8: DataType +INT16: DataType INT32: DataType -UINT32: DataType INT64: DataType +UINT8: DataType +UINT16: DataType +UINT32: DataType UINT64: DataType -FLOAT: DataType -DOUBLE: DataType MAX: ReduceOperation MIN: ReduceOperation SUM: ReduceOperation @@ -55,6 +63,22 @@ class AllgatherReply(_message.Message): receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... +class AllgatherVRequest(_message.Message): + __slots__ = ("sequence_number", "rank", "send_buffer") + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + RANK_FIELD_NUMBER: _ClassVar[int] + SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] + sequence_number: int + rank: int + send_buffer: bytes + def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... + +class AllgatherVReply(_message.Message): + __slots__ = ("receive_buffer",) + RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] + receive_buffer: bytes + def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... + class AllreduceRequest(_message.Message): __slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation") SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py index 1a6ab35b98..549d0e4ffc 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/federated_pb2_grpc.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: federated.proto +# Protobuf Python Version: 4.25.1 # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc -from .federated_pb2 import AllgatherReply, AllreduceReply, AllgatherRequest, AllreduceRequest, BroadcastRequest, BroadcastReply - +import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as federated__pb2 class FederatedStub(object): """Missing associated documentation comment in .proto file.""" @@ -28,19 +32,24 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Allgather = channel.unary_unary( - '/xgboost.federated.Federated/Allgather', - request_serializer=AllgatherRequest.SerializeToString, - response_deserializer=AllgatherReply.FromString, + '/xgboost.collective.federated.Federated/Allgather', + request_serializer=federated__pb2.AllgatherRequest.SerializeToString, + response_deserializer=federated__pb2.AllgatherReply.FromString, + ) + self.AllgatherV = channel.unary_unary( + '/xgboost.collective.federated.Federated/AllgatherV', + request_serializer=federated__pb2.AllgatherVRequest.SerializeToString, + response_deserializer=federated__pb2.AllgatherVReply.FromString, ) self.Allreduce = channel.unary_unary( - '/xgboost.federated.Federated/Allreduce', - request_serializer=AllreduceRequest.SerializeToString, - response_deserializer=AllreduceReply.FromString, + '/xgboost.collective.federated.Federated/Allreduce', + request_serializer=federated__pb2.AllreduceRequest.SerializeToString, + response_deserializer=federated__pb2.AllreduceReply.FromString, ) self.Broadcast = channel.unary_unary( - '/xgboost.federated.Federated/Broadcast', - request_serializer=BroadcastRequest.SerializeToString, - response_deserializer=BroadcastReply.FromString, + '/xgboost.collective.federated.Federated/Broadcast', + request_serializer=federated__pb2.BroadcastRequest.SerializeToString, + response_deserializer=federated__pb2.BroadcastReply.FromString, ) @@ -53,6 +62,12 @@ def Allgather(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def AllgatherV(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Allreduce(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -70,22 +85,27 @@ def add_FederatedServicer_to_server(servicer, server): rpc_method_handlers = { 'Allgather': grpc.unary_unary_rpc_method_handler( servicer.Allgather, - request_deserializer=AllgatherRequest.FromString, - response_serializer=AllgatherReply.SerializeToString, + request_deserializer=federated__pb2.AllgatherRequest.FromString, + response_serializer=federated__pb2.AllgatherReply.SerializeToString, + ), + 'AllgatherV': grpc.unary_unary_rpc_method_handler( + servicer.AllgatherV, + request_deserializer=federated__pb2.AllgatherVRequest.FromString, + response_serializer=federated__pb2.AllgatherVReply.SerializeToString, ), 'Allreduce': grpc.unary_unary_rpc_method_handler( servicer.Allreduce, - request_deserializer=AllreduceRequest.FromString, - response_serializer=AllreduceReply.SerializeToString, + request_deserializer=federated__pb2.AllreduceRequest.FromString, + response_serializer=federated__pb2.AllreduceReply.SerializeToString, ), 'Broadcast': grpc.unary_unary_rpc_method_handler( servicer.Broadcast, - request_deserializer=BroadcastRequest.FromString, - 
response_serializer=BroadcastReply.SerializeToString, + request_deserializer=federated__pb2.BroadcastRequest.FromString, + response_serializer=federated__pb2.BroadcastReply.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( - 'xgboost.federated.Federated', rpc_method_handlers) + 'xgboost.collective.federated.Federated', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) @@ -104,9 +124,26 @@ def Allgather(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Allgather', - AllgatherRequest.SerializeToString, - AllgatherReply.FromString, + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allgather', + federated__pb2.AllgatherRequest.SerializeToString, + federated__pb2.AllgatherReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AllgatherV(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/AllgatherV', + federated__pb2.AllgatherVRequest.SerializeToString, + federated__pb2.AllgatherVReply.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -121,9 +158,9 @@ def Allreduce(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Allreduce', - AllreduceRequest.SerializeToString, - AllreduceReply.FromString, + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allreduce', + federated__pb2.AllreduceRequest.SerializeToString, + federated__pb2.AllreduceReply.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -138,8 +175,8 @@ def Broadcast(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Broadcast', - BroadcastRequest.SerializeToString, - BroadcastReply.FromString, + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Broadcast', + federated__pb2.BroadcastRequest.SerializeToString, + federated__pb2.BroadcastReply.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh b/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh index 10afcf5b3b..f174f5d30f 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh +++ b/nvflare/app_opt/xgboost/histogram_based_v2/proto/gen_proto.sh @@ -1 +1,6 @@ +#!/usr/bin/env sh +# Install grpcio-tools: +# pip install grpcio-tools +# or +# mamba install grpcio-tools python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpc_python_out=. 
federated.proto diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/client_runner.py index e67c9d7868..dcd2c1fb92 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/client_runner.py @@ -161,7 +161,7 @@ def run(self, ctx: dict): self.logger.info(f"server address is {self._server_addr}") communicator_env = { - "xgboost_communicator": "federated", + "dmlc_communicator": "federated", "federated_server_address": f"{self._server_addr}", "federated_world_size": self._world_size, "federated_rank": self._rank, diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/server_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/server_runner.py index fc409e380c..9383e3268f 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/server_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/server_runner.py @@ -32,10 +32,7 @@ def run(self, ctx: dict): _world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE, None) self._stopped = False - xgb_federated.run_federated_server( - port=_port, - world_size=_world_size, - ) + xgb_federated.run_federated_server(_world_size, _port) self._stopped = True def stop(self):
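
--
Reviewer notes (not part of the patch):

1. ReliableMessage usage after this change: register_request_handler() now takes
   an fl_ctx, because the handler is additionally registered as a plain
   aux-message handler (used when a caller's tx_timeout is too small for the
   retry/query protocol), and send_request() degrades to a single non-retried
   send when tx_timeout <= per_msg_timeout. A minimal sketch; the topic, target
   name, timeout values, and echo handler below are hypothetical, not taken from
   the changed code:

       from nvflare.apis.fl_context import FLContext
       from nvflare.apis.shareable import Shareable
       from nvflare.apis.signal import Signal
       from nvflare.apis.utils.reliable_message import ReliableMessage


       def handle_request(topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
           # Hypothetical handler: echo the request back as the result.
           return request


       def register(fl_ctx: FLContext):
           # Receiving side. ReliableMessage.enable(fl_ctx) must have been called
           # first, otherwise this raises RuntimeError.
           ReliableMessage.register_request_handler(
               topic="example.topic",  # hypothetical topic
               handler_f=handle_request,
               fl_ctx=fl_ctx,  # newly required by this patch
           )


       def send(fl_ctx: FLContext, request: Shareable, abort_signal: Signal) -> Shareable:
           # Sending side. With tx_timeout > per_msg_timeout the request is re-sent
           # and queried until tx_timeout expires; otherwise it is sent only once.
           return ReliableMessage.send_request(
               target="site-1",  # hypothetical target name
               topic="example.topic",
               request=request,
               per_msg_timeout=10.0,
               tx_timeout=60.0,
               abort_signal=abort_signal,
               fl_ctx=fl_ctx,
           )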
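2. The "xgboost_communicator" -> "dmlc_communicator" rename in both executors
   matters because XGBoost 2.x selects its collective backend from the
   dmlc_communicator key; with the old key the federated channel is never
   engaged. A sketch of how such an environment is consumed, assuming the usual
   CommunicatorContext entry point (address, world size, and rank values are
   illustrative):

       import xgboost as xgb

       communicator_env = {
           "dmlc_communicator": "federated",              # key read by XGBoost 2.x
           "federated_server_address": "localhost:9091",  # illustrative
           "federated_world_size": 2,
           "federated_rank": 0,
       }
       # Collective ops (allgather/allreduce/broadcast) inside this context are
       # routed through the federated gRPC channel instead of a local communicator.
       with xgb.collective.CommunicatorContext(**communicator_env):
           pass  # xgb.train(...) would run here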
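3. The server_runner.py change calls run_federated_server() positionally because
   the keyword names differ across XGBoost releases (in 2.1.x the worker count
   appears to be the first parameter, followed by the port, and the old
   world_size= keyword is gone). Sketch with illustrative values:

       import xgboost.federated as xgb_federated

       # Serve 2 federated clients on port 9091; the call blocks until the
       # server shuts down, which is why _stopped is only set after it returns.
       xgb_federated.run_federated_server(2, 9091)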