diff --git a/nvflare/job_config/fed_job_config.py b/nvflare/job_config/fed_job_config.py index 480f613788..ea5f30f639 100644 --- a/nvflare/job_config/fed_job_config.py +++ b/nvflare/job_config/fed_job_config.py @@ -14,8 +14,11 @@ import builtins import inspect import json +import logging import os +import shlex import shutil +import subprocess import sys from enum import Enum from tempfile import TemporaryDirectory @@ -25,7 +28,7 @@ from nvflare.job_config.base_app_config import BaseAppConfig from nvflare.job_config.fed_app_config import FedAppConfig from nvflare.private.fed.app.fl_conf import FL_PACKAGES -from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner +from nvflare.private.fed.app.utils import kill_child_processes CONFIG = "config" CUSTOM = "custom" @@ -58,6 +61,7 @@ def __init__(self, job_name, min_clients, mandatory_clients=None) -> None: self.resource_specs: Dict[str, Dict] = {} self.custom_modules = [] + self.logger = logging.getLogger(self.__class__.__name__) def add_fed_app(self, app_name: str, fed_app: FedAppConfig): if not isinstance(fed_app, FedAppConfig): @@ -136,15 +140,31 @@ def simulator_run(self, workspace, clients=None, n_clients=None, threads=None, g with TemporaryDirectory() as job_root: self.generate_job_config(job_root) - simulator = SimulatorRunner( - job_folder=os.path.join(job_root, self.job_name), - workspace=workspace, - clients=clients, - n_clients=n_clients, - threads=threads, - gpu=gpu, - ) - simulator.run() + try: + command = ( + f"{sys.executable} -m nvflare.private.fed.app.simulator.simulator " + + os.path.join(job_root, self.job_name) + + " -w " + + workspace + ) + if clients: + command += " -c " + str(clients) + if n_clients: + command += " -n " + str(n_clients) + if threads: + command += " -t " + str(threads) + if gpu: + command += " -gpu " + str(gpu) + + new_env = os.environ.copy() + process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, env=new_env) + + process.wait() + + except KeyboardInterrupt: + self.logger.info("KeyboardInterrupt, terminate all the child processes.") + kill_child_processes(os.getpid()) + return -9 def _get_server_app(self, config_dir, custom_dir, fed_app): server_app = {"format_version": 2, "workflows": []}