diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py
index f2445ef55a..c76c2f70b9 100644
--- a/nvflare/apis/fl_constant.py
+++ b/nvflare/apis/fl_constant.py
@@ -163,6 +163,7 @@ class ReservedTopic(object):
     ABORT_ASK = "__abort_task__"
     AUX_COMMAND = "__aux_command__"
     JOB_HEART_BEAT = "__job_heartbeat__"
+    TASK_CHECK = "__task_check__"


 class AdminCommandNames(object):
diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py
index c0887f3047..00ff367e33 100644
--- a/nvflare/apis/impl/controller.py
+++ b/nvflare/apis/impl/controller.py
@@ -335,6 +335,11 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
             if not self._dead_client_reports.get(client_name):
                 self._dead_client_reports[client_name] = time.time()

+    def process_task_check(self, task_id: str, fl_ctx: FLContext):
+        with self._task_lock:
+            # task_id is the uuid associated with the client_task
+            return self._client_task_map.get(task_id, None)
+
     def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
         """Called to process a submission from one client.

diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py
index a411b30bd4..acb5c6b62e 100644
--- a/nvflare/apis/responder.py
+++ b/nvflare/apis/responder.py
@@ -63,6 +63,19 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
         """
         pass

+    @abstractmethod
+    def process_task_check(self, task_id: str, fl_ctx: FLContext):
+        """Called by the Engine to check whether a specified task still exists.
+
+        Args:
+            task_id: the id of the task
+            fl_ctx: the FLContext
+
+        Returns: the ClientTask object if it exists; None otherwise
+
+        """
+        pass
+
     @abstractmethod
     def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
         """Called by the Engine to handle the case that the job on the client is dead.
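Because `process_task_check` is added to `Responder` as an `@abstractmethod`, any custom responder that does not extend the built-in `Controller` must now implement it. A minimal sketch of what that could look like, mirroring the `Controller` change above (the class name and the internal map are illustrative assumptions, not part of this diff):

```python
import threading

from nvflare.apis.fl_context import FLContext


class MyResponder:  # illustrative; a real implementation would extend Responder
    def __init__(self):
        self._task_lock = threading.Lock()
        self._client_task_map = {}  # task_id (uuid) -> ClientTask; assumed bookkeeping

    def process_task_check(self, task_id: str, fl_ctx: FLContext):
        # Return the ClientTask if the task is still outstanding; returning None makes
        # the server runner reply TASK_UNKNOWN, which tells the client to drop its result.
        with self._task_lock:
            return self._client_task_map.get(task_id)
```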
diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py
index 15a684d47e..9184926186 100644
--- a/nvflare/private/fed/client/client_runner.py
+++ b/nvflare/private/fed/client/client_runner.py
@@ -29,6 +29,10 @@
 from nvflare.security.logging import secure_format_exception
 from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector

+_TASK_CHECK_RESULT_OK = 0
+_TASK_CHECK_RESULT_TRY_AGAIN = 1
+_TASK_CHECK_RESULT_TASK_GONE = 2
+

 class ClientRunnerConfig(object):
     def __init__(
@@ -107,7 +111,8 @@ def __init__(
         self.task_lock = threading.Lock()
         self.end_run_fired = False
         self.end_run_lock = threading.Lock()
-
+        self.task_check_timeout = 5.0
+        self.task_check_interval = 5.0
         self._register_aux_message_handler(engine)

     def _register_aux_message_handler(self, engine):
@@ -473,19 +478,108 @@ def fetch_and_run_one_task(self, fl_ctx) -> (float, bool):
         if cookie_jar:
             task_reply.set_cookie_jar(cookie_jar)

-        reply_sent = self.engine.send_task_result(task_reply, fl_ctx)
-        if reply_sent:
-            self.log_info(fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id))
-        else:
-            self.log_error(
-                fl_ctx,
-                "failed to send result to server for task: name={}, id={}".format(task.name, task.task_id),
-            )
+        self._send_task_result(task_reply, task.task_id, fl_ctx)

         self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT")
         self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx)
         return task_fetch_interval, True

+    def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext):
+        try_count = 1
+        while True:
+            self.log_info(fl_ctx, f"try #{try_count}: sending task result to server")
+
+            if self.asked_to_stop:
+                self.log_info(fl_ctx, "job aborted: stopped trying to send result")
+                return False
+
+            try_count += 1
+            rc = self._try_send_result_once(result, task_id, fl_ctx)
+
+            if rc == _TASK_CHECK_RESULT_OK:
+                return True
+            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
+                return False
+            else:
+                # retry
+                time.sleep(self.task_check_interval)
+
+    def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext):
+        # wait until the server is ready to receive
+        while True:
+            if self.asked_to_stop:
+                return _TASK_CHECK_RESULT_TASK_GONE
+
+            rc = self._check_task_once(task_id, fl_ctx)
+            if rc == _TASK_CHECK_RESULT_OK:
+                break
+            elif rc == _TASK_CHECK_RESULT_TASK_GONE:
+                return rc
+            else:
+                # try again
+                time.sleep(self.task_check_interval)
+
+        # try to send the result
+        self.log_info(fl_ctx, "start to send task result to server")
+        reply_sent = self.engine.send_task_result(result, fl_ctx)
+        if reply_sent:
+            self.log_info(fl_ctx, "task result sent to server")
+            return _TASK_CHECK_RESULT_OK
+        else:
+            self.log_error(fl_ctx, "failed to send task result to server - will try again")
+            return _TASK_CHECK_RESULT_TRY_AGAIN
+
+    def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int:
+        """Check whether the server is still waiting for the specified task.
+
+        This check exists mainly to cope with unstable network connections: before pushing the
+        task result to the server, we make sure the connection is available and the task is
+        still wanted. If the task check succeeds, the connection is likely to be available;
+        otherwise we keep retrying until the check succeeds or the server tells us that the
+        task is gone (e.g. it timed out).
+
+        Args:
+            task_id: the id of the task to check
+            fl_ctx: the FLContext
+
+        Returns:
+            One of the _TASK_CHECK_RESULT_ codes defined above.
+        """
+        self.log_info(fl_ctx, "checking task ...")
+        task_check_req = Shareable()
+        task_check_req.set_header(ReservedKey.TASK_ID, task_id)
+        resp = self.engine.send_aux_request(
+            targets=[FQCN.ROOT_SERVER],
+            topic=ReservedTopic.TASK_CHECK,
+            request=task_check_req,
+            timeout=self.task_check_timeout,
+            fl_ctx=fl_ctx,
+            optional=True,
+        )
+        if resp and isinstance(resp, dict):
+            reply = resp.get(FQCN.ROOT_SERVER)
+            if not isinstance(reply, Shareable):
+                self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}")
+                return _TASK_CHECK_RESULT_TRY_AGAIN
+
+            rc = reply.get_return_code()
+            if rc == ReturnCode.OK:
+                return _TASK_CHECK_RESULT_OK
+            elif rc == ReturnCode.COMMUNICATION_ERROR:
+                self.log_error(fl_ctx, f"failed task_check: {rc}")
+                return _TASK_CHECK_RESULT_TRY_AGAIN
+            elif rc == ReturnCode.SERVER_NOT_READY:
+                self.log_error(fl_ctx, f"server rejected task_check: {rc}")
+                return _TASK_CHECK_RESULT_TRY_AGAIN
+            elif rc == ReturnCode.TASK_UNKNOWN:
+                self.log_error(fl_ctx, f"task no longer exists on server: {rc}")
+                return _TASK_CHECK_RESULT_TASK_GONE
+            else:
+                # this should never happen
+                self.log_error(fl_ctx, f"programming error: received {rc} from server")
+                return _TASK_CHECK_RESULT_OK  # try to push the result regardless
+        else:
+            self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}")
+            return _TASK_CHECK_RESULT_TRY_AGAIN
+
     def run(self, app_root, args):
         self.init_run(app_root, args)
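For reference, `_check_task_once` expects `send_aux_request` to return a dict keyed by the server FQCN, whose value is a `Shareable` carrying a return code. A sketch of the three cases and the internal codes they map to (illustrative replies only; import paths assumed from the NVFlare code base, not part of this diff):

```python
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.shareable import make_reply
from nvflare.fuel.f3.cellnet.fqcn import FQCN

# Server still has the task -> _TASK_CHECK_RESULT_OK (push the result now).
ok_resp = {FQCN.ROOT_SERVER: make_reply(ReturnCode.OK)}

# Server says the task is gone -> _TASK_CHECK_RESULT_TASK_GONE (drop the result).
gone_resp = {FQCN.ROOT_SERVER: make_reply(ReturnCode.TASK_UNKNOWN)}

# Empty or invalid response (e.g. connection trouble) -> _TASK_CHECK_RESULT_TRY_AGAIN (sleep and retry).
empty_resp = {}
```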
diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py
index 32d0bb0b3e..3a0a66d267 100644
--- a/nvflare/private/fed/client/communicator.py
+++ b/nvflare/private/fed/client/communicator.py
@@ -65,6 +65,7 @@ def __init__(
         cell: Cell = None,
         client_register_interval=2,
         timeout=5.0,
+        maint_msg_timeout=5.0,
     ):
         """To init the Communicator.

@@ -85,7 +86,7 @@ def __init__(
         self.compression = compression
         self.client_register_interval = client_register_interval
         self.timeout = timeout
-
+        self.maint_msg_timeout = maint_msg_timeout
         self.logger = logging.getLogger(self.__class__.__name__)

     def client_registration(self, client_name, servers, project_name):
@@ -129,7 +130,7 @@ def client_registration(self, client_name, servers, project_name):
             channel=CellChannel.SERVER_MAIN,
             topic=CellChannelTopic.Register,
             request=login_message,
-            timeout=self.timeout,
+            timeout=self.maint_msg_timeout,
         )
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
         if return_code == ReturnCode.UNAUTHENTICATED:
@@ -297,7 +298,7 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
             channel=CellChannel.SERVER_MAIN,
             topic=CellChannelTopic.Quit,
             request=quit_message,
-            timeout=self.timeout,
+            timeout=self.maint_msg_timeout,
         )
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
         if return_code == ReturnCode.UNAUTHENTICATED:
@@ -335,9 +336,13 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C
             channel=CellChannel.SERVER_MAIN,
             topic=CellChannelTopic.HEART_BEAT,
             request=heartbeat_message,
-            timeout=self.timeout,
+            timeout=self.maint_msg_timeout,
         )
         return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
+
+        if return_code != ReturnCode.OK:
+            self.logger.error(f"heartbeat error: {return_code}")
+
         if return_code == ReturnCode.UNAUTHENTICATED:
             unauthenticated = result.get_header(MessageHeaderKey.ERROR)
             raise FLCommunicationError("error:client_quit " + unauthenticated)
diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py
index 6817d34bb7..90f3d3b6ef 100644
--- a/nvflare/private/fed/client/fed_client_base.py
+++ b/nvflare/private/fed/client/fed_client_base.py
@@ -104,6 +104,7 @@ def __init__(
             cell=cell,
             client_register_interval=client_args.get("client_register_interval", 2.0),
             timeout=client_args.get("communication_timeout", 30.0),
+            maint_msg_timeout=client_args.get("maint_msg_timeout", 5.0),
         )

         self.secure_train = secure_train
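fed_client_base.py reads the new knob from the client args alongside the existing timeouts, so maintenance traffic (register, quit, heartbeat) gets a short timeout while task data exchange keeps the longer one. A sketch of the keys and the defaults used above (the variable name is hypothetical, and where these keys surface in a deployment's client configuration is an assumption):

```python
# Keys read by FedClientBase when constructing the Communicator, with the defaults from this diff.
local_client_args = {
    "client_register_interval": 2.0,  # seconds between registration attempts
    "communication_timeout": 30.0,    # timeout for task data exchange
    "maint_msg_timeout": 5.0,         # timeout for register/quit/heartbeat messages (new)
}
```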
diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py
index 2d17182511..0a0f3ac704 100644
--- a/nvflare/private/fed/server/server_runner.py
+++ b/nvflare/private/fed/server/server_runner.py
@@ -106,6 +106,8 @@ def _register_aux_message_handler(self, engine):
             topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat
         )

+        engine.register_aux_message_handler(topic=ReservedTopic.TASK_CHECK, message_handle_func=self._handle_task_check)
+
     def _execute_run(self):
         while self.current_wf_index < len(self.config.workflows):
             wf = self.config.workflows[self.current_wf_index]
@@ -500,6 +502,28 @@ def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContex
         self.log_info(fl_ctx, "received client job_heartbeat aux request")
         return make_reply(ReturnCode.OK)

+    def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
+        task_id = request.get_header(ReservedHeaderKey.TASK_ID)
+        if not task_id:
+            self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request")
+            return make_reply(ReturnCode.BAD_REQUEST_DATA)
+
+        self.log_info(fl_ctx, f"received task_check on task {task_id}")
+
+        with self.wf_lock:
+            if self.current_wf is None or self.current_wf.responder is None:
+                self.log_info(fl_ctx, "no current workflow - dropped task_check.")
+                return make_reply(ReturnCode.TASK_UNKNOWN)
+
+            # ask the responder whether the task is still outstanding
+            task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx)
+            if task:
+                self.log_info(fl_ctx, f"task {task_id} is still good")
+                return make_reply(ReturnCode.OK)
+            else:
+                self.log_info(fl_ctx, f"task {task_id} is not found")
+                return make_reply(ReturnCode.TASK_UNKNOWN)
+
     def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False):
         self.status = "done"
         self.abort_signal.trigger(value=True)
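Putting the two sides together: the client runner builds a small `Shareable` carrying only the task id and sends it on the `__task_check__` aux topic; `_handle_task_check` answers with one of three return codes. A sketch of the request and the outcome mapping (the header constant is taken from the client-side code above and is assumed to name the same header the server reads via `ReservedHeaderKey.TASK_ID`; the id value is a placeholder):

```python
from nvflare.apis.fl_constant import ReservedKey
from nvflare.apis.shareable import Shareable

# The task_check request carries only the task id, exactly as _check_task_once builds it.
req = Shareable()
req.set_header(ReservedKey.TASK_ID, "<task-uuid>")  # placeholder id

# Server-side outcomes from _handle_task_check, and how the client reacts:
#   missing TASK_ID header              -> BAD_REQUEST_DATA (unexpected by the client; it logs an
#                                          error and pushes the result anyway)
#   no current workflow / unknown task  -> TASK_UNKNOWN     (client drops the result)
#   task still tracked by the responder -> OK               (client pushes the result now)
```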