Skip to content

Commit

Permalink
Fix AV FL Issues (#2053)
Browse files Browse the repository at this point in the history
* fix heartbeat timeout; add retry logic for task result submission

* increased timeout
  • Loading branch information
yanchengnv authored Oct 4, 2023
1 parent 6aebf9c commit 55bbdf8
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 13 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class ReservedTopic(object):
# Reserved aux-message topics used internally by the framework.
ABORT_ASK = "__abort_task__"
AUX_COMMAND = "__aux_command__"
JOB_HEART_BEAT = "__job_heartbeat__"
# Client asks the server whether a previously assigned task is still pending.
TASK_CHECK = "__task_check__"


class AdminCommandNames(object):
Expand Down
5 changes: 5 additions & 0 deletions nvflare/apis/impl/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,11 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
if not self._dead_client_reports.get(client_name):
self._dead_client_reports[client_name] = time.time()

def process_task_check(self, task_id: str, fl_ctx: FLContext):
    """Look up the ClientTask for a task-check request.

    Args:
        task_id: the uuid assigned to the client_task when it was created
        fl_ctx: the FLContext

    Returns:
        The ClientTask registered under task_id, or None if no such task.
    """
    # Hold the task lock so the lookup does not race with concurrent
    # task assignment/removal.
    with self._task_lock:
        return self._client_task_map.get(task_id)

def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext):
"""Called to process a submission from one client.
Expand Down
13 changes: 13 additions & 0 deletions nvflare/apis/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
"""
pass

@abstractmethod
def process_task_check(self, task_id: str, fl_ctx: FLContext):
    """Called by the Engine to check whether a specified task still exists.

    Args:
        task_id: the id of the task
        fl_ctx: the FLContext

    Returns:
        The ClientTask object if the task still exists; None otherwise.
    """
    pass

@abstractmethod
def handle_dead_job(self, client_name: str, fl_ctx: FLContext):
"""Called by the Engine to handle the case that the job on the client is dead.
Expand Down
112 changes: 103 additions & 9 deletions nvflare/private/fed/client/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from nvflare.security.logging import secure_format_exception
from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector

# Result codes used by the task-check / result-send retry helpers below.
_TASK_CHECK_RESULT_OK = 0  # server confirmed the task / result accepted
_TASK_CHECK_RESULT_TRY_AGAIN = 1  # transient failure - retry after an interval
_TASK_CHECK_RESULT_TASK_GONE = 2  # server no longer knows the task - stop retrying


class ClientRunnerConfig(object):
def __init__(
Expand Down Expand Up @@ -107,7 +111,8 @@ def __init__(
self.task_lock = threading.Lock()
self.end_run_fired = False
self.end_run_lock = threading.Lock()

self.task_check_timeout = 5.0
self.task_check_interval = 5.0
self._register_aux_message_handler(engine)

def _register_aux_message_handler(self, engine):
Expand Down Expand Up @@ -473,19 +478,108 @@ def fetch_and_run_one_task(self, fl_ctx) -> (float, bool):
if cookie_jar:
task_reply.set_cookie_jar(cookie_jar)

reply_sent = self.engine.send_task_result(task_reply, fl_ctx)
if reply_sent:
self.log_info(fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id))
else:
self.log_error(
fl_ctx,
"failed to send result to server for task: name={}, id={}".format(task.name, task.task_id),
)
self._send_task_result(task_reply, task.task_id, fl_ctx)
self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT")
self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx)

return task_fetch_interval, True

def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext) -> bool:
    """Send a task result to the server, retrying until delivered or hopeless.

    Keeps retrying (sleeping task_check_interval between attempts) until the
    result is accepted, the server reports the task is gone, or the job is
    asked to stop.

    Args:
        result: the task result to send
        task_id: id of the task this result belongs to
        fl_ctx: the FLContext

    Returns:
        True if the result was delivered to the server; False if the job was
        aborted or the server no longer knows the task.
    """
    try_count = 1
    while True:
        # Check for abort BEFORE logging an attempt; otherwise an already
        # aborted job logs a "try #N" that never actually happens.
        if self.asked_to_stop:
            self.log_info(fl_ctx, "job aborted: stopped trying to send result")
            return False

        self.log_info(fl_ctx, f"try #{try_count}: sending task result to server")
        try_count += 1
        rc = self._try_send_result_once(result, task_id, fl_ctx)

        if rc == _TASK_CHECK_RESULT_OK:
            return True
        if rc == _TASK_CHECK_RESULT_TASK_GONE:
            return False

        # transient failure - wait and retry
        time.sleep(self.task_check_interval)

def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext):
    """Wait until the server confirms the task, then make one send attempt.

    Returns:
        One of the _TASK_CHECK_RESULT_* codes: OK if the result was sent,
        TASK_GONE if the task no longer exists (or the job was stopped),
        TRY_AGAIN if the send attempt failed.
    """
    # Poll the server with task checks until it confirms the task is still
    # wanted (which also tells us the connection is usable), the task is
    # gone, or the job is stopped.
    while True:
        if self.asked_to_stop:
            return _TASK_CHECK_RESULT_TASK_GONE

        check_rc = self._check_task_once(task_id, fl_ctx)
        if check_rc == _TASK_CHECK_RESULT_TASK_GONE:
            return check_rc
        if check_rc == _TASK_CHECK_RESULT_OK:
            break
        # transient problem - check again after a pause
        time.sleep(self.task_check_interval)

    # server confirmed the task - make a single delivery attempt
    self.log_info(fl_ctx, "start to send task result to server")
    if self.engine.send_task_result(result, fl_ctx):
        self.log_info(fl_ctx, "task result sent to server")
        return _TASK_CHECK_RESULT_OK

    self.log_error(fl_ctx, "failed to send task result to server - will try again")
    return _TASK_CHECK_RESULT_TRY_AGAIN

def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int:
    """Ask the server whether it is still waiting for the specified task.

    The real reason for this method is to fight against unstable network
    connections: before pushing a (potentially large) task result, we probe
    with a small task-check message. If the check succeeds, the connection
    is likely available; otherwise the caller keeps retrying until the check
    succeeds or the server says the task is gone (timed out).

    Args:
        task_id: id of the task to check
        fl_ctx: the FLContext

    Returns:
        One of the _TASK_CHECK_RESULT_* codes.
    """
    self.log_info(fl_ctx, "checking task ...")
    req = Shareable()
    req.set_header(ReservedKey.TASK_ID, task_id)
    resp = self.engine.send_aux_request(
        targets=[FQCN.ROOT_SERVER],
        topic=ReservedTopic.TASK_CHECK,
        request=req,
        timeout=self.task_check_timeout,
        fl_ctx=fl_ctx,
        optional=True,
    )

    # guard: aux responses come back as a dict keyed by target FQCN
    if not resp or not isinstance(resp, dict):
        self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}")
        return _TASK_CHECK_RESULT_TRY_AGAIN

    reply = resp.get(FQCN.ROOT_SERVER)
    if not isinstance(reply, Shareable):
        self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}")
        return _TASK_CHECK_RESULT_TRY_AGAIN

    rc = reply.get_return_code()
    if rc == ReturnCode.OK:
        return _TASK_CHECK_RESULT_OK
    if rc == ReturnCode.TASK_UNKNOWN:
        self.log_error(fl_ctx, f"task no longer exists on server: {rc}")
        return _TASK_CHECK_RESULT_TASK_GONE
    if rc == ReturnCode.COMMUNICATION_ERROR:
        self.log_error(fl_ctx, f"failed task_check: {rc}")
        return _TASK_CHECK_RESULT_TRY_AGAIN
    if rc == ReturnCode.SERVER_NOT_READY:
        self.log_error(fl_ctx, f"server rejected task_check: {rc}")
        return _TASK_CHECK_RESULT_TRY_AGAIN

    # this should never happen
    self.log_error(fl_ctx, f"programming error: received {rc} from server")
    return _TASK_CHECK_RESULT_OK  # try to push the result regardless

def run(self, app_root, args):
self.init_run(app_root, args)

Expand Down
13 changes: 9 additions & 4 deletions nvflare/private/fed/client/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
cell: Cell = None,
client_register_interval=2,
timeout=5.0,
maint_msg_timeout=5.0,
):
"""To init the Communicator.
Expand All @@ -85,7 +86,7 @@ def __init__(
self.compression = compression
self.client_register_interval = client_register_interval
self.timeout = timeout

self.maint_msg_timeout = maint_msg_timeout
self.logger = logging.getLogger(self.__class__.__name__)

def client_registration(self, client_name, servers, project_name):
Expand Down Expand Up @@ -129,7 +130,7 @@ def client_registration(self, client_name, servers, project_name):
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.Register,
request=login_message,
timeout=self.timeout,
timeout=self.maint_msg_timeout,
)
return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
if return_code == ReturnCode.UNAUTHENTICATED:
Expand Down Expand Up @@ -297,7 +298,7 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.Quit,
request=quit_message,
timeout=self.timeout,
timeout=self.maint_msg_timeout,
)
return_code = result.get_header(MessageHeaderKey.RETURN_CODE)
if return_code == ReturnCode.UNAUTHENTICATED:
Expand Down Expand Up @@ -335,9 +336,13 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C
channel=CellChannel.SERVER_MAIN,
topic=CellChannelTopic.HEART_BEAT,
request=heartbeat_message,
timeout=self.timeout,
timeout=self.maint_msg_timeout,
)
return_code = result.get_header(MessageHeaderKey.RETURN_CODE)

if return_code != ReturnCode.OK:
self.logger.error(f"heartbeat error: {return_code}")

if return_code == ReturnCode.UNAUTHENTICATED:
unauthenticated = result.get_header(MessageHeaderKey.ERROR)
raise FLCommunicationError("error:client_quit " + unauthenticated)
Expand Down
1 change: 1 addition & 0 deletions nvflare/private/fed/client/fed_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
cell=cell,
client_register_interval=client_args.get("client_register_interval", 2.0),
timeout=client_args.get("communication_timeout", 30.0),
maint_msg_timeout=client_args.get("maint_msg_timeout", 5.0),
)

self.secure_train = secure_train
Expand Down
24 changes: 24 additions & 0 deletions nvflare/private/fed/server/server_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def _register_aux_message_handler(self, engine):
topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat
)

engine.register_aux_message_handler(topic=ReservedTopic.TASK_CHECK, message_handle_func=self._handle_task_check)

def _execute_run(self):
while self.current_wf_index < len(self.config.workflows):
wf = self.config.workflows[self.current_wf_index]
Expand Down Expand Up @@ -500,6 +502,28 @@ def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContex
self.log_info(fl_ctx, "received client job_heartbeat aux request")
return make_reply(ReturnCode.OK)

def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
    """Aux-message handler for task-check requests from clients.

    A client asks whether this server is still waiting for a given task
    before (re)sending its result.

    Args:
        topic: the aux message topic
        request: the request Shareable; must carry the task id header
        fl_ctx: the FLContext

    Returns:
        A reply with ReturnCode.OK if the task is still pending,
        TASK_UNKNOWN if it is not (or no workflow is running), or
        BAD_REQUEST_DATA if the request lacks a task id.
    """
    task_id = request.get_header(ReservedHeaderKey.TASK_ID)
    if not task_id:
        self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request")
        return make_reply(ReturnCode.BAD_REQUEST_DATA)

    self.log_info(fl_ctx, f"received task_check on task {task_id}")

    with self.wf_lock:
        wf = self.current_wf
        if wf is None or wf.responder is None:
            self.log_info(fl_ctx, "no current workflow - dropped task_check.")
            return make_reply(ReturnCode.TASK_UNKNOWN)

        # ask the current workflow's responder whether this task is still outstanding
        client_task = wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx)
        if not client_task:
            self.log_info(fl_ctx, f"task {task_id} is not found")
            return make_reply(ReturnCode.TASK_UNKNOWN)

        self.log_info(fl_ctx, f"task {task_id} is still good")
        return make_reply(ReturnCode.OK)

def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False):
self.status = "done"
self.abort_signal.trigger(value=True)
Expand Down

0 comments on commit 55bbdf8

Please sign in to comment.