diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index a47d3b8131..327ec23d0e 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -236,6 +236,7 @@ class ServerCommandKey(object): CLIENTS = "clients" COLLECTOR = "collector" TURN_TO_COLD = "__turn_to_cold__" + REASON = "reason" class FedEventHeader(object): diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 071e0a3585..7b4d5475ca 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -80,7 +80,7 @@ def _get_client_task(target, task: Task): class _DeadClientStatus: def __init__(self): self.report_time = time.time() - self.dead_time = None + self.death_time = None class Controller(Responder, ControllerSpec, ABC): @@ -356,8 +356,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): """ # record the report and to be used by the task monitor with self._dead_clients_lock: - self.log_warning(fl_ctx, f"received dead job report for client {client_name}") - if not self._dead_clients.get(client_name): + if self._dead_clients.get(client_name): + # already on watch list + self.log_warning(fl_ctx, f"discarded dead job report for client {client_name}: already on watch list") + else: self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list") self._dead_clients[client_name] = _DeadClientStatus() @@ -856,12 +858,36 @@ def relay_and_wait( self.wait_for_task(task, abort_signal) def _monitor_tasks(self): + grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0) while not self._all_done: - clients_all_dead = self._check_dead_clients() - if not clients_all_dead: + self._determine_dead_clients(grace_period) + if self._all_clients_dead(): + with self._engine.new_context() as fl_ctx: + self.system_panic("All clients are dead", fl_ctx) + return + else: self._check_tasks() time.sleep(self._task_check_period) + def _determine_dead_clients(self, grace_period): + if not self._dead_clients: + return + + now = time.time() + with self._dead_clients_lock: + for client_name, status in self._dead_clients.items(): + if status.death_time: + # already dead + continue + + if now - status.report_time < grace_period: + # this report is still fresh - consider the client to be still alive + continue + + # consider client dead + status.death_time = now + self.logger.error(f"Client {client_name} is deemed dead!") + def _check_tasks(self): with self._controller_lock: self._do_check_tasks() @@ -897,7 +923,7 @@ def _do_check_tasks(self): # check whether clients that the task is waiting are all dead dead_clients = self._get_task_dead_clients(task) if dead_clients: - self.logger.info(f"client {dead_clients} is dead - set task {task.name} to TIMEOUT") + self.logger.info(f"clients {dead_clients} dead - set task {task.name} to TIMEOUT") task.completion_status = TaskCompletionStatus.CLIENT_DEAD exit_tasks.append(task) continue @@ -949,22 +975,21 @@ def _get_task_dead_clients(self, task: Task): return None dead_clients = [] - with self._dead_clients_lock: - for target in task.targets: - ct = _get_client_task(target, task) - if ct is not None and ct.result_received_time: - # response has been received from this client - continue - - # either we have not sent the task to this client or we have not received response - # is the client already dead? - if self._client_still_alive(target): - # this client is still alive - # we let the task continue its course since we still have live clients - return None - else: - # this client is dead - remember it - dead_clients.append(target) + for target in task.targets: + ct = _get_client_task(target, task) + if ct is not None and ct.result_received_time: + # response has been received from this client + continue + + # either we have not sent the task to this client or we have not received response + # is the client already dead? + if self.get_client_death_time(target): + # this client is dead - remember it + dead_clients.append(target) + else: + # this client is still alive + # we let the task continue its course since we still have live clients + return None return dead_clients @@ -993,46 +1018,18 @@ def wait_for_task(self, task: Task, abort_signal: Signal): break time.sleep(self._task_check_period) - def _check_dead_clients(self): + def _all_clients_dead(self): if self._engine: clients = self._engine.get_clients() - dead_clients = [] - with self._dead_clients_lock: - for client in clients: - if not self._client_still_alive(client.name): - dead_clients.append(client.name) - - if dead_clients and len(clients) == len(dead_clients): - with self._engine.new_context() as fl_ctx: - self.system_panic("All clients are dead", fl_ctx) - return True - return False - - def _client_still_alive(self, client_name): - now = time.time() - status = self._dead_clients.get(client_name, None) - grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0) - - if not status: - # this client is still alive - return True - - assert isinstance(status, _DeadClientStatus) - if status.dead_time: - return False - - if now - status.report_time < grace_period: - # this report is still fresh - consider the client to be still alive + for client in clients: + if not self.get_client_death_time(client.name): + # this client is still alive + return False return True - - # consider client dead - status.dead_time = now - self.logger.error(f"Client {client_name} is deemed dead!") - self.client_is_dead(client_name) return False def get_client_death_time(self, client_name: str): - """Get the time that the client was deemed dead + """Get the time that the client was deemed dead/disconnected Args: client_name: name of the client @@ -1042,18 +1039,17 @@ def get_client_death_time(self, client_name: str): """ status = self._dead_clients.get(client_name) if status: - assert isinstance(status, _DeadClientStatus) - return status.dead_time + return status.death_time return None - def process_job_heartbeat(self, fl_ctx: FLContext): + def process_job_heartbeat(self, fl_ctx: FLContext, reason: str): peer_ctx = fl_ctx.get_peer_context() assert isinstance(peer_ctx, FLContext) client_name = peer_ctx.get_identity_name() with self._dead_clients_lock: if client_name in self._dead_clients: - self.log_info(fl_ctx, f"Client {client_name} is removed from watch list") + self.log_info(fl_ctx, f"Client {client_name} is removed from watch list: {reason=}") status = self._dead_clients.pop(client_name) - if status.dead_time: - self.log_info(fl_ctx, f"Client {client_name} is revived") + if status.death_time: + self.log_info(fl_ctx, f"Client {client_name} is revived: {reason=}") self.client_is_revived(client_name) diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py index 955dcc6ae0..6816d002be 100644 --- a/nvflare/apis/responder.py +++ b/nvflare/apis/responder.py @@ -77,11 +77,12 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext): pass @abstractmethod - def process_job_heartbeat(self, fl_ctx: FLContext): + def process_job_heartbeat(self, fl_ctx: FLContext, reason: str): """Called by the Engine to handle heartbeat received from clients. Args: fl_ctx: the FLContext + reason: reason of the HB Returns: None diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index d1e3d80d70..e45f6df73b 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -139,15 +139,15 @@ def start_controller(self, fl_ctx: FLContext): self._last_client = None def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]: - if len(self._participating_clients) <= 1: - self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx) - return None - active_clients_map = {} for t in self._participating_clients: if not self.get_client_death_time(t.name): active_clients_map[t.name] = t + if len(active_clients_map) <= 1: + self.system_panic(f"Not enough client sites (active_clients={len(active_clients_map)}).", fl_ctx) + return None + if isinstance(self._order, list): targets = [] for c_name in self._order: diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 18f71e5259..86a793545f 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -35,7 +35,6 @@ WorkspaceConstants, ) from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable from nvflare.apis.workspace import Workspace from nvflare.fuel.common.exit_codes import ProcessExitCode from nvflare.fuel.f3.cellnet.cell import Cell, Message @@ -234,7 +233,7 @@ def fl_shutdown(self): self.shutdown = True start = time.time() while self.client_manager.clients: - # Wait for the clients to shutdown and quite first. + # Wait for the clients to shut down and quite first. time.sleep(0.1) if time.time() - start > self.shutdown_period: self.logger.info("There are still clients connected. But shutdown the server after timeout.") @@ -580,26 +579,17 @@ def _sync_client_jobs(self, request, client_token): # this is a dict: token => nvflare.apis.client.Client client = participating_clients.get(client_token, None) if client: - self._notify_dead_job(client, job_id) + self._notify_dead_job(client, job_id, "missing job on client") return jobs_need_abort - def _notify_dead_job(self, client, job_id: str): + def _notify_dead_job(self, client, job_id: str, reason: str): try: - with self.engine.lock: - shareable = Shareable() - shareable.set_header(ServerCommandKey.FL_CLIENT, client.name) - fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id]) - request = new_cell_message({}, fobs.dumps(shareable)) - self.cell.fire_and_forget( - targets=fqcn, - channel=CellChannel.SERVER_COMMAND, - topic=ServerCommandNames.HANDLE_DEAD_JOB, - message=request, - optional=True, - ) - except Exception: - self.logger.info("Could not connect to server runner process") + self.engine.notify_dead_job(job_id, client.name, reason) + except Exception as ex: + self.logger.info( + f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}" + ) def notify_dead_client(self, client): """Called to do further processing of the dead client @@ -618,7 +608,7 @@ def notify_dead_client(self, client): assert isinstance(process_info, dict) participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None) if participating_clients and client.token in participating_clients: - self._notify_dead_job(client, job_id) + self._notify_dead_job(client, job_id, "client dead") def start_run(self, job_id, run_root, conf, args, snapshot): # Create the FL Engine diff --git a/nvflare/private/fed/server/server_commands.py b/nvflare/private/fed/server/server_commands.py index d74f94056d..358712614b 100644 --- a/nvflare/private/fed/server/server_commands.py +++ b/nvflare/private/fed/server/server_commands.py @@ -264,6 +264,8 @@ def process(self, data: Shareable, fl_ctx: FLContext): """ client_name = data.get_header(ServerCommandKey.FL_CLIENT) + reason = data.get_header(ServerCommandKey.REASON) + self.logger.warning(f"received dead job notification: {client_name=}; {reason=}") server_runner = fl_ctx.get_prop(FLContextKey.RUNNER) if server_runner: server_runner.handle_dead_job(client_name, fl_ctx) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index 17be86c40c..b06f5bb07d 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -583,6 +583,19 @@ def update_job_run_status(self): message=request, ) + def notify_dead_job(self, job_id: str, client_name: str, reason: str): + shareable = Shareable() + shareable.set_header(ServerCommandKey.FL_CLIENT, client_name) + shareable.set_header(ServerCommandKey.REASON, reason) + self.send_command_to_child_runner_process( + job_id=job_id, + command_name=ServerCommandNames.HANDLE_DEAD_JOB, + command_data=shareable, + timeout=0.0, + optional=True, + ) + self.logger.warning(f"notified SJ of dead-job: {job_id=}; {client_name=}; {reason=}") + def send_command_to_child_runner_process( self, job_id: str, command_name: str, command_data, timeout=5.0, optional=False ): @@ -594,7 +607,7 @@ def send_command_to_child_runner_process( targets=fqcn, channel=CellChannel.SERVER_COMMAND, topic=command_name, - request=request, + message=request, optional=optional, ) return None diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index 43e6685554..c069ce9c3b 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -329,7 +329,7 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005): return "", "", None if self.current_wf.responder: - self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx) + self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx, reason="getTask") task_name, task_id, task_data = self.current_wf.responder.process_task_request(client, fl_ctx) @@ -371,7 +371,6 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): try: if self.current_wf is None: return - self.current_wf.responder.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx) except Exception as e: self.log_exception( @@ -445,7 +444,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul return if self.current_wf.responder: - self.current_wf.responder.process_job_heartbeat(fl_ctx) + self.current_wf.responder.process_job_heartbeat(fl_ctx, "submitTask") wf_id = result.get_cookie(ReservedHeaderKey.WORKFLOW, None) if wf_id is not None and wf_id != self.current_wf.id: @@ -508,7 +507,7 @@ def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContex self.log_debug(fl_ctx, "received client job_heartbeat") with self.wf_lock: if self.current_wf and self.current_wf.responder: - self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx) + self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx, reason="jobHeartbeat") return make_reply(ReturnCode.OK) def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: @@ -525,7 +524,7 @@ def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) return make_reply(ReturnCode.TASK_UNKNOWN) if self.current_wf.responder: - self.current_wf.responder.process_job_heartbeat(fl_ctx) + self.current_wf.responder.process_job_heartbeat(fl_ctx, "taskCheck") # filter task result task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx) diff --git a/nvflare/private/fed/server/sys_cmd.py b/nvflare/private/fed/server/sys_cmd.py index 4ae643d852..eea16434b9 100644 --- a/nvflare/private/fed/server/sys_cmd.py +++ b/nvflare/private/fed/server/sys_cmd.py @@ -68,6 +68,14 @@ def get_spec(self): authz_func=self.authorize_server_operation, visible=True, ), + CommandSpec( + name="dead", + description="send dead client msg to SJ", + usage="dead ", + handler_func=self.dead_client, + authz_func=self.must_be_project_admin, + visible=False, + ), ], ) @@ -156,3 +164,13 @@ def report_resources(self, conn: Connection, args: List[str]): table = conn.append_table(["Sites", "Resources"]) for k, v in site_resources.items(): table.add_row([str(k), str(v)]) + + def dead_client(self, conn: Connection, args: List[str]): + if len(args) != 3: + conn.append_error(f"Usage: {args[0]} client_name job_id") + return + client_name = args[1] + job_id = args[2] + engine = conn.app_ctx + engine.notify_dead_job(job_id, client_name, f"AdminCommand: {args[0]}") + conn.append_string(f"called notify_dead_job for client {client_name=} {job_id=}")