From 2348139522c8238f5a76a80dc860294cfb537284 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Tue, 23 Jan 2024 15:04:46 -0800 Subject: [PATCH] Move workspace setup inside constructor --- .../fed/app/simulator/simulator_runner.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index d363450170..589173ccd2 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -94,6 +94,15 @@ def __init__( self.clients_created = 0 + running_dir = os.getcwd() + if self.workspace is None: + self.workspace = "simulator_workspace" + self.logger.warn( + f"Simulator workspace is not provided. Set it to the default location:" + f" {os.path.join(running_dir, self.workspace)}" + ) + self.workspace = os.path.join(running_dir, self.workspace) + def _generate_args( self, job_folder: str, workspace: str, clients=None, n_clients=None, threads=None, gpu=None, max_clients=100 ): @@ -110,15 +119,6 @@ def _generate_args( return args def setup(self): - running_dir = os.getcwd() - if self.workspace is None: - self.workspace = "simulator_workspace" - self.logger.warn( - f"Simulator workspace is not provided. Set it to the default location:" - f" {os.path.join(running_dir, self.workspace)}" - ) - self.workspace = os.path.join(running_dir, self.workspace) - self.args = self._generate_args( self.job_folder, self.workspace, self.clients, self.n_clients, self.threads, self.gpu, self.max_clients ) @@ -348,7 +348,7 @@ def run(self): try: manager = Manager() return_dict = manager.dict() - process = Process(target=self.run_processs, args=(return_dict,)) + process = Process(target=self.run_process, args=(return_dict,)) process.start() process.join() run_status = self._get_return_code(return_dict, process, self.workspace) @@ -380,7 +380,7 @@ def _get_return_code(self, return_dict, process, workspace): self.logger.info(f"return_code from process.exitcode: {return_code}") return return_code - def run_processs(self, return_dict): + def run_process(self, return_dict): # run_status = self.simulator_run_main() try: run_status = mpm.run(