[2.3] Improve dead client handling (#2501)
* improve dead client handling

* remove unused import

* fix docstring

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
Co-authored-by: Chester Chen <[email protected]>
3 people authored Apr 17, 2024
1 parent ed670ac commit 16730ec
Showing 9 changed files with 112 additions and 92 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
@@ -236,6 +236,7 @@ class ServerCommandKey(object):
CLIENTS = "clients"
COLLECTOR = "collector"
TURN_TO_COLD = "__turn_to_cold__"
REASON = "reason"


class FedEventHeader(object):
120 changes: 58 additions & 62 deletions nvflare/apis/impl/controller.py
@@ -80,7 +80,7 @@ def _get_client_task(target, task: Task):
class _DeadClientStatus:
def __init__(self):
self.report_time = time.time()
self.dead_time = None
self.death_time = None


class Controller(Responder, ControllerSpec, ABC):
@@ -356,8 +356,10 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""
# record the report and to be used by the task monitor
with self._dead_clients_lock:
self.log_warning(fl_ctx, f"received dead job report for client {client_name}")
if not self._dead_clients.get(client_name):
if self._dead_clients.get(client_name):
# already on watch list
self.log_warning(fl_ctx, f"discarded dead job report for client {client_name}: already on watch list")
else:
self.log_warning(fl_ctx, f"client {client_name} is placed on dead client watch list")
self._dead_clients[client_name] = _DeadClientStatus()

@@ -856,12 +858,36 @@ def relay_and_wait(
self.wait_for_task(task, abort_signal)

def _monitor_tasks(self):
grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0)
while not self._all_done:
clients_all_dead = self._check_dead_clients()
if not clients_all_dead:
self._determine_dead_clients(grace_period)
if self._all_clients_dead():
with self._engine.new_context() as fl_ctx:
self.system_panic("All clients are dead", fl_ctx)
return
else:
self._check_tasks()
time.sleep(self._task_check_period)

def _determine_dead_clients(self, grace_period):
if not self._dead_clients:
return

now = time.time()
with self._dead_clients_lock:
for client_name, status in self._dead_clients.items():
if status.death_time:
# already dead
continue

if now - status.report_time < grace_period:
# this report is still fresh - consider the client to be still alive
continue

# consider client dead
status.death_time = now
self.logger.error(f"Client {client_name} is deemed dead!")

def _check_tasks(self):
with self._controller_lock:
self._do_check_tasks()
@@ -897,7 +923,7 @@ def _do_check_tasks(self):
# check whether clients that the task is waiting are all dead
dead_clients = self._get_task_dead_clients(task)
if dead_clients:
self.logger.info(f"client {dead_clients} is dead - set task {task.name} to TIMEOUT")
self.logger.info(f"clients {dead_clients} dead - set task {task.name} to TIMEOUT")
task.completion_status = TaskCompletionStatus.CLIENT_DEAD
exit_tasks.append(task)
continue
@@ -949,22 +975,21 @@ def _get_task_dead_clients(self, task: Task):
return None

dead_clients = []
with self._dead_clients_lock:
for target in task.targets:
ct = _get_client_task(target, task)
if ct is not None and ct.result_received_time:
# response has been received from this client
continue

# either we have not sent the task to this client or we have not received response
# is the client already dead?
if self._client_still_alive(target):
# this client is still alive
# we let the task continue its course since we still have live clients
return None
else:
# this client is dead - remember it
dead_clients.append(target)
for target in task.targets:
ct = _get_client_task(target, task)
if ct is not None and ct.result_received_time:
# response has been received from this client
continue

# either we have not sent the task to this client or we have not received response
# is the client already dead?
if self.get_client_death_time(target):
# this client is dead - remember it
dead_clients.append(target)
else:
# this client is still alive
# we let the task continue its course since we still have live clients
return None

return dead_clients

@@ -993,46 +1018,18 @@ def wait_for_task(self, task: Task, abort_signal: Signal):
break
time.sleep(self._task_check_period)

def _check_dead_clients(self):
def _all_clients_dead(self):
if self._engine:
clients = self._engine.get_clients()
dead_clients = []
with self._dead_clients_lock:
for client in clients:
if not self._client_still_alive(client.name):
dead_clients.append(client.name)

if dead_clients and len(clients) == len(dead_clients):
with self._engine.new_context() as fl_ctx:
self.system_panic("All clients are dead", fl_ctx)
return True
return False

def _client_still_alive(self, client_name):
now = time.time()
status = self._dead_clients.get(client_name, None)
grace_period = ConfigService.get_float_var(name=_CONFIG_VAR_DEAD_CLIENT_GRACE_PERIOD, default=60.0)

if not status:
# this client is still alive
return True

assert isinstance(status, _DeadClientStatus)
if status.dead_time:
return False

if now - status.report_time < grace_period:
# this report is still fresh - consider the client to be still alive
for client in clients:
if not self.get_client_death_time(client.name):
# this client is still alive
return False
return True

# consider client dead
status.dead_time = now
self.logger.error(f"Client {client_name} is deemed dead!")
self.client_is_dead(client_name)
return False

def get_client_death_time(self, client_name: str):
"""Get the time that the client was deemed dead
"""Get the time that the client was deemed dead/disconnected
Args:
client_name: name of the client
@@ -1042,18 +1039,17 @@ def get_client_death_time(self, client_name: str):
"""
status = self._dead_clients.get(client_name)
if status:
assert isinstance(status, _DeadClientStatus)
return status.dead_time
return status.death_time
return None

def process_job_heartbeat(self, fl_ctx: FLContext):
def process_job_heartbeat(self, fl_ctx: FLContext, reason: str):
peer_ctx = fl_ctx.get_peer_context()
assert isinstance(peer_ctx, FLContext)
client_name = peer_ctx.get_identity_name()
with self._dead_clients_lock:
if client_name in self._dead_clients:
self.log_info(fl_ctx, f"Client {client_name} is removed from watch list")
self.log_info(fl_ctx, f"Client {client_name} is removed from watch list: {reason=}")
status = self._dead_clients.pop(client_name)
if status.dead_time:
self.log_info(fl_ctx, f"Client {client_name} is revived")
if status.death_time:
self.log_info(fl_ctx, f"Client {client_name} is revived: {reason=}")
self.client_is_revived(client_name)
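
The net effect of the controller changes: a dead-job report only puts a client on a watch list; the task monitor later promotes it to dead once a configurable grace period has passed without a heartbeat, and any heartbeat removes (and, if needed, revives) the client. A minimal, self-contained sketch of that lifecycle (not the NVFlare class itself; GRACE_PERIOD, watch_list, and the helper names are illustrative):

import time

GRACE_PERIOD = 60.0  # NVFlare reads this from ConfigService (default 60.0 seconds)

class _Status:
    def __init__(self):
        self.report_time = time.time()  # when the dead-job report arrived
        self.death_time = None          # set once the grace period expires

watch_list = {}  # client_name -> _Status

def report_dead_job(client_name):
    # the first report puts the client on the watch list; repeats are discarded
    watch_list.setdefault(client_name, _Status())

def determine_dead_clients():
    now = time.time()
    for name, status in watch_list.items():
        if status.death_time is None and now - status.report_time >= GRACE_PERIOD:
            status.death_time = now  # deemed dead only after the grace period

def heartbeat(client_name, reason):
    # any heartbeat (getTask, submitTask, taskCheck, jobHeartbeat) clears the entry
    status = watch_list.pop(client_name, None)
    if status is not None and status.death_time is not None:
        print(f"client {client_name} revived ({reason=})")

The real task monitor runs this check once per task-check period and panics with "All clients are dead" only when every participating client has a death_time set.
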
3 changes: 2 additions & 1 deletion nvflare/apis/responder.py
@@ -77,11 +77,12 @@ def process_task_check(self, task_id: str, fl_ctx: FLContext):
pass

@abstractmethod
def process_job_heartbeat(self, fl_ctx: FLContext):
def process_job_heartbeat(self, fl_ctx: FLContext, reason: str):
"""Called by the Engine to handle heartbeat received from clients.
Args:
fl_ctx: the FLContext
reason: reason for the heartbeat (HB)
Returns: None
8 changes: 4 additions & 4 deletions nvflare/app_common/workflows/cyclic_ctl.py
@@ -139,15 +139,15 @@ def start_controller(self, fl_ctx: FLContext):
self._last_client = None

def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
if len(self._participating_clients) <= 1:
self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
return None

active_clients_map = {}
for t in self._participating_clients:
if not self.get_client_death_time(t.name):
active_clients_map[t.name] = t

if len(active_clients_map) <= 1:
self.system_panic(f"Not enough client sites (active_clients={len(active_clients_map)}).", fl_ctx)
return None

if isinstance(self._order, list):
targets = []
for c_name in self._order:
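
For the cyclic workflow, the fix moves the "enough clients" check after dead clients are filtered out, so the panic message reflects the number of active sites rather than the configured ones. A hedged sketch of that filtering step, assuming a controller that exposes get_client_death_time as in the diff above (active_relay_targets and the error handling are illustrative):

def active_relay_targets(controller, participating_clients):
    # keep only clients that have not been deemed dead
    alive = [c for c in participating_clients if not controller.get_client_death_time(c.name)]
    if len(alive) <= 1:
        # a cyclic relay needs at least two live sites to make progress
        raise RuntimeError(f"Not enough client sites (active_clients={len(alive)}).")
    return alive
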
28 changes: 9 additions & 19 deletions nvflare/private/fed/server/fed_server.py
@@ -35,7 +35,6 @@
WorkspaceConstants,
)
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.exit_codes import ProcessExitCode
from nvflare.fuel.f3.cellnet.cell import Cell, Message
@@ -234,7 +233,7 @@ def fl_shutdown(self):
self.shutdown = True
start = time.time()
while self.client_manager.clients:
# Wait for the clients to shutdown and quite first.
# Wait for the clients to shut down and quit first.
time.sleep(0.1)
if time.time() - start > self.shutdown_period:
self.logger.info("There are still clients connected. But shutdown the server after timeout.")
@@ -580,26 +579,17 @@ def _sync_client_jobs(self, request, client_token):
# this is a dict: token => nvflare.apis.client.Client
client = participating_clients.get(client_token, None)
if client:
self._notify_dead_job(client, job_id)
self._notify_dead_job(client, job_id, "missing job on client")

return jobs_need_abort

def _notify_dead_job(self, client, job_id: str):
def _notify_dead_job(self, client, job_id: str, reason: str):
try:
with self.engine.lock:
shareable = Shareable()
shareable.set_header(ServerCommandKey.FL_CLIENT, client.name)
fqcn = FQCN.join([FQCN.ROOT_SERVER, job_id])
request = new_cell_message({}, fobs.dumps(shareable))
self.cell.fire_and_forget(
targets=fqcn,
channel=CellChannel.SERVER_COMMAND,
topic=ServerCommandNames.HANDLE_DEAD_JOB,
message=request,
optional=True,
)
except Exception:
self.logger.info("Could not connect to server runner process")
self.engine.notify_dead_job(job_id, client.name, reason)
except Exception as ex:
self.logger.info(
f"Failed to notify_dead_job to runner process of job {job_id}: {secure_format_exception(ex)}"
)

def notify_dead_client(self, client):
"""Called to do further processing of the dead client
@@ -618,7 +608,7 @@ def notify_dead_client(self, client):
assert isinstance(process_info, dict)
participating_clients = process_info.get(RunProcessKey.PARTICIPANTS, None)
if participating_clients and client.token in participating_clients:
self._notify_dead_job(client, job_id)
self._notify_dead_job(client, job_id, "client dead")

def start_run(self, job_id, run_root, conf, args, snapshot):
# Create the FL Engine
2 changes: 2 additions & 0 deletions nvflare/private/fed/server/server_commands.py
@@ -264,6 +264,8 @@ def process(self, data: Shareable, fl_ctx: FLContext):
"""
client_name = data.get_header(ServerCommandKey.FL_CLIENT)
reason = data.get_header(ServerCommandKey.REASON)
self.logger.warning(f"received dead job notification: {client_name=}; {reason=}")
server_runner = fl_ctx.get_prop(FLContextKey.RUNNER)
if server_runner:
server_runner.handle_dead_job(client_name, fl_ctx)
15 changes: 14 additions & 1 deletion nvflare/private/fed/server/server_engine.py
@@ -583,6 +583,19 @@ def update_job_run_status(self):
message=request,
)

def notify_dead_job(self, job_id: str, client_name: str, reason: str):
shareable = Shareable()
shareable.set_header(ServerCommandKey.FL_CLIENT, client_name)
shareable.set_header(ServerCommandKey.REASON, reason)
self.send_command_to_child_runner_process(
job_id=job_id,
command_name=ServerCommandNames.HANDLE_DEAD_JOB,
command_data=shareable,
timeout=0.0,
optional=True,
)
self.logger.warning(f"notified SJ of dead-job: {job_id=}; {client_name=}; {reason=}")

def send_command_to_child_runner_process(
self, job_id: str, command_name: str, command_data, timeout=5.0, optional=False
):
@@ -594,7 +607,7 @@ def send_command_to_child_runner_process(
targets=fqcn,
channel=CellChannel.SERVER_COMMAND,
topic=command_name,
request=request,
message=request,
optional=optional,
)
return None
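
The engine method above centralizes the cell messaging: the parent server calls notify_dead_job with a human-readable reason ("missing job on client" or "client dead"), and the dead-job command handler in server_commands.py reads the same headers back before invoking the server runner. A small sketch of just the header round trip (the "site-1" and "client dead" values are placeholders):

from nvflare.apis.fl_constant import ServerCommandKey
from nvflare.apis.shareable import Shareable

# sender side: pack the notification the way notify_dead_job does
msg = Shareable()
msg.set_header(ServerCommandKey.FL_CLIENT, "site-1")    # placeholder client name
msg.set_header(ServerCommandKey.REASON, "client dead")  # reason string supplied by the caller

# receiver side: the job process unpacks the same headers before calling
# the server runner's handle_dead_job(client_name, fl_ctx)
client_name = msg.get_header(ServerCommandKey.FL_CLIENT)
reason = msg.get_header(ServerCommandKey.REASON)
assert (client_name, reason) == ("site-1", "client dead")
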
9 changes: 4 additions & 5 deletions nvflare/private/fed/server/server_runner.py
@@ -329,7 +329,7 @@ def _try_to_get_task(self, client, fl_ctx, timeout=None, retry_interval=0.005):
return "", "", None

if self.current_wf.responder:
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx, reason="getTask")

task_name, task_id, task_data = self.current_wf.responder.process_task_request(client, fl_ctx)

@@ -371,7 +371,6 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
try:
if self.current_wf is None:
return

self.current_wf.responder.handle_dead_job(client_name=client_name, fl_ctx=fl_ctx)
except Exception as e:
self.log_exception(
@@ -445,7 +444,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
return

if self.current_wf.responder:
self.current_wf.responder.process_job_heartbeat(fl_ctx)
self.current_wf.responder.process_job_heartbeat(fl_ctx, "submitTask")

wf_id = result.get_cookie(ReservedHeaderKey.WORKFLOW, None)
if wf_id is not None and wf_id != self.current_wf.id:
@@ -508,7 +507,7 @@ def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContex
self.log_debug(fl_ctx, "received client job_heartbeat")
with self.wf_lock:
if self.current_wf and self.current_wf.responder:
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx)
self.current_wf.responder.process_job_heartbeat(fl_ctx=fl_ctx, reason="jobHeartbeat")
return make_reply(ReturnCode.OK)

def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
@@ -525,7 +524,7 @@ def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext)
return make_reply(ReturnCode.TASK_UNKNOWN)

if self.current_wf.responder:
self.current_wf.responder.process_job_heartbeat(fl_ctx)
self.current_wf.responder.process_job_heartbeat(fl_ctx, "taskCheck")

# filter task result
task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx)
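
Every server-side interaction with a client now doubles as a job heartbeat, and each call site tags the heartbeat with a short reason string so the watch-list log explains what cleared the entry. A compact sketch of that pattern (forward_heartbeat and the responder stand-in are illustrative; the reason strings are the ones used in this commit):

HEARTBEAT_REASONS = ("getTask", "submitTask", "jobHeartbeat", "taskCheck")

def forward_heartbeat(responder, fl_ctx, reason):
    # mirrors the runner's guard: only an active workflow with a responder is notified
    assert reason in HEARTBEAT_REASONS
    if responder is not None:
        responder.process_job_heartbeat(fl_ctx=fl_ctx, reason=reason)
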
18 changes: 18 additions & 0 deletions nvflare/private/fed/server/sys_cmd.py
@@ -68,6 +68,14 @@ def get_spec(self):
authz_func=self.authorize_server_operation,
visible=True,
),
CommandSpec(
name="dead",
description="send dead client msg to SJ",
usage="dead <client-name>",
handler_func=self.dead_client,
authz_func=self.must_be_project_admin,
visible=False,
),
],
)

@@ -156,3 +164,13 @@ def report_resources(self, conn: Connection, args: List[str]):
table = conn.append_table(["Sites", "Resources"])
for k, v in site_resources.items():
table.add_row([str(k), str(v)])

def dead_client(self, conn: Connection, args: List[str]):
if len(args) != 3:
conn.append_error(f"Usage: {args[0]} client_name job_id")
return
client_name = args[1]
job_id = args[2]
engine = conn.app_ctx
engine.notify_dead_job(job_id, client_name, f"AdminCommand: {args[0]}")
conn.append_string(f"called notify_dead_job for client {client_name=} {job_id=}")
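
The hidden dead admin command gives operators a way to inject a dead-client notification into a running job, mostly for testing the behavior above. A usage sketch based on the handler (site-1 and job1 are placeholder names; the output line follows the append_string call):

> dead site-1 job1
called notify_dead_job for client client_name='site-1' job_id='job1'

On the job process side, this arrives as a dead-job notification with reason "AdminCommand: dead".
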
