Fixed the client_executor improper lock use. (NVIDIA#2282)
Co-authored-by: Chester Chen <[email protected]>
yhwen and chesterxgchen authored Jan 16, 2024
1 parent 304c1d8 commit e4f8d41
Showing 1 changed file with 106 additions and 106 deletions.
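The thread through all the hunks below: several ProcessExecutor methods held self.lock across blocking cell calls (send_request, fire_and_forget), so one slow or dead job cell could stall every other thread that needed the lock, including heartbeat cleanup. The fix narrows the lock to the self.run_processes accesses it actually protects and bounds the blocking queries with a configurable job_query_timeout. A minimal sketch of the anti-pattern and the corrected shape (illustrative names, not NVFLARE's actual API):

import threading
import time

lock = threading.Lock()
run_processes = {}  # shared state that the lock is meant to protect

def slow_rpc(job_id):
    # Stand-in for a blocking cell call such as send_request().
    time.sleep(5.0)
    return {}

def get_run_info_old(job_id):
    # Anti-pattern: the lock is held for the full 5 s the RPC takes,
    # so notify_job_status(), check_status(), etc. all stall behind it.
    with lock:
        if job_id not in run_processes:
            return {}
        return slow_rpc(job_id)

def get_run_info_new(job_id):
    # Fixed shape: guard only the dict access, block with the lock released.
    with lock:
        known = job_id in run_processes
    if not known:
        return {}
    return slow_rpc(job_id)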
212 changes: 106 additions & 106 deletions nvflare/private/fed/client/client_executor.py
@@ -21,11 +21,12 @@
 import time
 from abc import ABC, abstractmethod
 
-from nvflare.apis.fl_constant import AdminCommandNames, RunProcessKey
+from nvflare.apis.fl_constant import AdminCommandNames, RunProcessKey, SystemConfigs
 from nvflare.apis.resource_manager_spec import ResourceManagerSpec
 from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode
 from nvflare.fuel.f3.cellnet.core_cell import FQCN
 from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode
+from nvflare.fuel.utils.config_service import ConfigService
 from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message
 from nvflare.private.fed.utils.fed_utils import get_return_code
 from nvflare.security.logging import secure_format_exception, secure_log_traceback
@@ -138,6 +139,10 @@ def __init__(self, client, startup):
         self.run_processes = {}
         self.lock = threading.Lock()
 
+        self.job_query_timeout = ConfigService.get_float_var(
+            name="job_query_timeout", conf=SystemConfigs.APPLICATION_CONF, default=5.0
+        )
+
     def start_app(
         self,
         client,
@@ -216,10 +221,9 @@ def start_app(
         thread.start()
 
     def notify_job_status(self, job_id, job_status):
-        with self.lock:
-            run_process = self.run_processes.get(job_id)
-            if run_process:
-                run_process[RunProcessKey.STATUS] = job_status
+        run_process = self.run_processes.get(job_id)
+        if run_process:
+            run_process[RunProcessKey.STATUS] = job_status
 
     def check_status(self, job_id):
         """Checks the status of the running client.
@@ -231,9 +235,8 @@
             A client status message
         """
         try:
-            with self.lock:
-                process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
-                return get_status_message(process_status)
+            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
+            return get_status_message(process_status)
         except Exception as e:
             self.logger.error(f"check_status execution exception: {secure_format_exception(e)}.")
             secure_log_traceback()
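A note on the two status hunks above: the read and the update of run_processes now happen without the lock. Under CPython, a single dict get or item assignment is atomic thanks to the GIL, so an unlocked reader sees either the old or the new status, never a torn value; the worst case is a momentarily stale answer, which is harmless for a status query. A self-contained illustration of the chained-get read, with simplified stand-in names:

run_processes = {"job-1": {"status": "STARTED"}}

def check_status(job_id):
    # Unknown job_id and missing status key both fall back to a default,
    # mirroring run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ...).
    return run_processes.get(job_id, {}).get("status", "NOT_STARTED")

assert check_status("job-1") == "STARTED"
assert check_status("job-2") == "NOT_STARTED"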
@@ -249,23 +252,23 @@ def get_run_info(self, job_id):
             A dict of run information.
         """
         try:
-            with self.lock:
-                data = {}
-                fqcn = FQCN.join([self.client.client_name, job_id])
-                request = new_cell_message({}, data)
-                return_data = self.client.cell.send_request(
-                    target=fqcn,
-                    channel=CellChannel.CLIENT_COMMAND,
-                    topic=AdminCommandNames.SHOW_STATS,
-                    request=request,
-                    optional=True,
-                )
-                return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE)
-                if return_code == ReturnCode.OK:
-                    run_info = return_data.payload
-                    return run_info
-                else:
-                    return {}
+            data = {}
+            fqcn = FQCN.join([self.client.client_name, job_id])
+            request = new_cell_message({}, data)
+            return_data = self.client.cell.send_request(
+                target=fqcn,
+                channel=CellChannel.CLIENT_COMMAND,
+                topic=AdminCommandNames.SHOW_STATS,
+                request=request,
+                optional=True,
+                timeout=self.job_query_timeout,
+            )
+            return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE)
+            if return_code == ReturnCode.OK:
+                run_info = return_data.payload
+                return run_info
+            else:
+                return {}
         except Exception as e:
             self.logger.error(f"get_run_info execution exception: {secure_format_exception(e)}.")
             secure_log_traceback()
@@ -281,23 +284,23 @@ def get_errors(self, job_id):
             A dict of error information.
         """
         try:
-            with self.lock:
-                data = {"command": AdminCommandNames.SHOW_ERRORS, "data": {}}
-                fqcn = FQCN.join([self.client.client_name, job_id])
-                request = new_cell_message({}, data)
-                return_data = self.client.cell.send_request(
-                    target=fqcn,
-                    channel=CellChannel.CLIENT_COMMAND,
-                    topic=AdminCommandNames.SHOW_ERRORS,
-                    request=request,
-                    optional=True,
-                )
-                return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE)
-                if return_code == ReturnCode.OK:
-                    errors_info = return_data.payload
-                    return errors_info
-                else:
-                    return None
+            data = {"command": AdminCommandNames.SHOW_ERRORS, "data": {}}
+            fqcn = FQCN.join([self.client.client_name, job_id])
+            request = new_cell_message({}, data)
+            return_data = self.client.cell.send_request(
+                target=fqcn,
+                channel=CellChannel.CLIENT_COMMAND,
+                topic=AdminCommandNames.SHOW_ERRORS,
+                request=request,
+                optional=True,
+                timeout=self.job_query_timeout,
+            )
+            return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE)
+            if return_code == ReturnCode.OK:
+                errors_info = return_data.payload
+                return errors_info
+            else:
+                return None
         except Exception as e:
             self.logger.error(f"get_errors execution exception: {secure_format_exception(e)}.")
             secure_log_traceback()
@@ -310,17 +313,16 @@ def reset_errors(self, job_id):
             job_id: the job_id
         """
         try:
-            with self.lock:
-                data = {"command": AdminCommandNames.RESET_ERRORS, "data": {}}
-                fqcn = FQCN.join([self.client.client_name, job_id])
-                request = new_cell_message({}, data)
-                self.client.cell.fire_and_forget(
-                    targets=fqcn,
-                    channel=CellChannel.CLIENT_COMMAND,
-                    topic=AdminCommandNames.RESET_ERRORS,
-                    message=request,
-                    optional=True,
-                )
+            data = {"command": AdminCommandNames.RESET_ERRORS, "data": {}}
+            fqcn = FQCN.join([self.client.client_name, job_id])
+            request = new_cell_message({}, data)
+            self.client.cell.fire_and_forget(
+                targets=fqcn,
+                channel=CellChannel.CLIENT_COMMAND,
+                topic=AdminCommandNames.RESET_ERRORS,
+                message=request,
+                optional=True,
+            )
 
         except Exception as e:
             self.logger.error(f"reset_errors execution exception: {secure_format_exception(e)}.")
@@ -332,41 +334,41 @@ def abort_app(self, job_id):
         Args:
             job_id: the job_id
         """
-        with self.lock:
-            # When the HeartBeat cleanup process try to abort the worker process, the job maybe already terminated,
-            # Use retry to avoid print out the error stack trace.
-            retry = 1
-            while retry >= 0:
-                process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
-                if process_status == ClientStatus.STARTED:
-                    try:
-                        child_process = self.run_processes[job_id][RunProcessKey.CHILD_PROCESS]
-                        data = {}
-                        fqcn = FQCN.join([self.client.client_name, job_id])
-                        request = new_cell_message({}, data)
-                        self.client.cell.fire_and_forget(
-                            targets=fqcn,
-                            channel=CellChannel.CLIENT_COMMAND,
-                            topic=AdminCommandNames.ABORT,
-                            message=request,
-                            optional=True,
-                        )
-                        self.logger.debug("abort sent to worker")
-                        t = threading.Thread(target=self._terminate_process, args=[child_process, job_id])
-                        t.start()
-                        t.join()
-                        break
-                    except Exception as e:
-                        if retry == 0:
-                            self.logger.error(
-                                f"abort_worker_process execution exception: {secure_format_exception(e)} for run: {job_id}."
-                            )
-                            secure_log_traceback()
-                        retry -= 1
-                        time.sleep(5.0)
-                else:
-                    self.logger.info(f"Client worker process for run: {job_id} was already terminated.")
-                    break
+        # When the HeartBeat cleanup process try to abort the worker process, the job maybe already terminated,
+        # Use retry to avoid print out the error stack trace.
+        retry = 1
+        while retry >= 0:
+            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
+            if process_status == ClientStatus.STARTED:
+                try:
+                    with self.lock:
+                        child_process = self.run_processes[job_id][RunProcessKey.CHILD_PROCESS]
+                    data = {}
+                    fqcn = FQCN.join([self.client.client_name, job_id])
+                    request = new_cell_message({}, data)
+                    self.client.cell.fire_and_forget(
+                        targets=fqcn,
+                        channel=CellChannel.CLIENT_COMMAND,
+                        topic=AdminCommandNames.ABORT,
+                        message=request,
+                        optional=True,
+                    )
+                    self.logger.debug("abort sent to worker")
+                    t = threading.Thread(target=self._terminate_process, args=[child_process, job_id])
+                    t.start()
+                    t.join()
+                    break
+                except Exception as e:
+                    if retry == 0:
+                        self.logger.error(
+                            f"abort_worker_process execution exception: {secure_format_exception(e)} for run: {job_id}."
+                        )
+                        secure_log_traceback()
+                    retry -= 1
+                    time.sleep(5.0)
+            else:
+                self.logger.info(f"Client worker process for run: {job_id} was already terminated.")
+                break
 
         self.logger.info("Client worker process is terminated.")

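abort_app above shows the other half of the pattern: the lock is now taken only long enough to read the child-process handle out of run_processes, while the abort message, the retry sleep, and the _terminate_process join all run with the lock released. A reduced sketch of that shape (hypothetical names, not the real implementation):

import threading

lock = threading.Lock()
run_processes = {}  # job_id -> {"child_process": subprocess.Popen handle}

def abort(job_id):
    with lock:
        # Only the shared-dict read is guarded.
        child = run_processes.get(job_id, {}).get("child_process")
    if child is None:
        return  # worker already gone; nothing to terminate
    child.terminate()  # blocking work happens with the lock released
    child.wait()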
@@ -405,25 +407,23 @@ def abort_task(self, job_id):
         Args:
             job_id: the job_id
         """
-        with self.lock:
-            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
-            if process_status == ClientStatus.STARTED:
-                data = {"command": AdminCommandNames.ABORT_TASK, "data": {}}
-                fqcn = FQCN.join([self.client.client_name, job_id])
-                request = new_cell_message({}, data)
-                self.client.cell.fire_and_forget(
-                    targets=fqcn,
-                    channel=CellChannel.CLIENT_COMMAND,
-                    topic=AdminCommandNames.ABORT_TASK,
-                    message=request,
-                    optional=True,
-                )
-                self.logger.debug("abort_task sent")
+        process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED)
+        if process_status == ClientStatus.STARTED:
+            data = {"command": AdminCommandNames.ABORT_TASK, "data": {}}
+            fqcn = FQCN.join([self.client.client_name, job_id])
+            request = new_cell_message({}, data)
+            self.client.cell.fire_and_forget(
+                targets=fqcn,
+                channel=CellChannel.CLIENT_COMMAND,
+                topic=AdminCommandNames.ABORT_TASK,
+                message=request,
+                optional=True,
+            )
+            self.logger.debug("abort_task sent")
 
     def _wait_child_process_finish(self, client, job_id, allocated_resource, token, resource_manager, workspace):
         self.logger.info(f"run ({job_id}): waiting for child worker process to finish.")
-        with self.lock:
-            child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS)
+        child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS)
         if child_process:
             child_process.wait()
@@ -452,13 +452,13 @@ def _wait_child_process_finish(self, client, job_id, allocated_resource, token, resource_manager, workspace):
             resource_manager.free_resources(
                 resources=allocated_resource, token=token, fl_ctx=client.engine.new_context()
             )
-        self.run_processes.pop(job_id, None)
+        with self.lock:
+            self.run_processes.pop(job_id, None)
         self.logger.debug(f"run ({job_id}): child worker resources freed.")
 
     def get_status(self, job_id):
-        with self.lock:
-            process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED)
-            return process_status
+        process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED)
+        return process_status
 
     def get_run_processes_keys(self):
         with self.lock:
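The new job_query_timeout (default 5.0 seconds, read once in __init__ from SystemConfigs.APPLICATION_CONF) bounds how long get_run_info and get_errors wait on the job cell, so a dead worker can no longer hang an admin query. A hedged sketch of the lookup pattern; the plain dict below stands in for NVFLARE's actual config sources and is not its real file format:

app_config = {"job_query_timeout": 10.0}  # e.g. parsed from an application config

def get_float_var(name, conf, default):
    # Mirrors the role of ConfigService.get_float_var: look the key up once,
    # fall back to the default when it is absent or not a number.
    try:
        return float(conf[name])
    except (KeyError, TypeError, ValueError):
        return default

job_query_timeout = get_float_var("job_query_timeout", app_config, default=5.0)
print(job_query_timeout)  # 10.0 here; 5.0 when the key is not configured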
