diff --git a/nvflare/apis/executor.py b/nvflare/apis/executor.py index 98d0f06b8a..ce6db35b81 100644 --- a/nvflare/apis/executor.py +++ b/nvflare/apis/executor.py @@ -31,6 +31,10 @@ class Executor(FLComponent, ABC): """ + def __init__(self): + FLComponent.__init__(self) + self.unsafe = False + @abstractmethod def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: """Executes a task. diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index b35747fcfa..b951b9e451 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -38,6 +38,7 @@ class ReturnCode(object): MODEL_UNRECOGNIZED = "MODEL_UNRECOGNIZED" VALIDATE_TYPE_UNKNOWN = "VALIDATE_TYPE_UNKNOWN" EMPTY_RESULT = "EMPTY_RESULT" + UNSAFE_JOB = "UNSAFE_JOB" SERVER_NOT_READY = "SERVER_NOT_READY" @@ -99,6 +100,7 @@ class ReservedKey(object): JOB_RUN_NUMBER = "__job_run_number__" JOB_DEPLOY_DETAIL = "__job_deploy_detail__" FATAL_SYSTEM_ERROR = "__fatal_system_error__" + JOB_IS_UNSAFE = "__job_is_unsafe__" class FLContextKey(object): diff --git a/nvflare/apis/fl_context.py b/nvflare/apis/fl_context.py index a3283eb053..bee5103661 100644 --- a/nvflare/apis/fl_context.py +++ b/nvflare/apis/fl_context.py @@ -209,6 +209,12 @@ def get_job_id(self, default=None): def get_identity_name(self, default=""): return self._simple_get(ReservedKey.IDENTITY_NAME, default=default) + def set_job_is_unsafe(self, value: bool = True): + self.set_prop(ReservedKey.JOB_IS_UNSAFE, value, private=True, sticky=True) + + def is_job_unsafe(self): + return self.get_prop(ReservedKey.JOB_IS_UNSAFE, False) + def get_run_abort_signal(self): return self._simple_get(key=ReservedKey.RUN_ABORT_SIGNAL, default=None) diff --git a/nvflare/apis/fl_exception.py b/nvflare/apis/fl_exception.py index 9cc8be0f00..5e51ebda18 100644 --- a/nvflare/apis/fl_exception.py +++ b/nvflare/apis/fl_exception.py @@ -27,3 +27,9 @@ def __init__(self, message, exception=None): if exception: self.__dict__.update(exception.__dict__) self.message = message + + +class UnsafeJobError(Exception): + """Raised when a job is detected to be unsafe""" + + pass diff --git a/nvflare/app_common/executors/learner_executor.py b/nvflare/app_common/executors/learner_executor.py index 65ebbc09b4..5304acdeda 100644 --- a/nvflare/app_common/executors/learner_executor.py +++ b/nvflare/app_common/executors/learner_executor.py @@ -35,10 +35,10 @@ def __init__( """Key component to run learner on clients. Args: - learner_id (str): id pointing to the learner object - train_task (str, optional): label to dispatch train task. Defaults to AppConstants.TASK_TRAIN. - submit_model_task (str, optional): label to dispatch submit model task. Defaults to AppConstants.TASK_SUBMIT_MODEL. - validate_task (str, optional): label to dispatch validation task. Defaults to AppConstants.TASK_VALIDATION. + learner_id (str): id of the learner object + train_task (str, optional): task name for train. Defaults to AppConstants.TASK_TRAIN. + submit_model_task (str, optional): task name for submit model. Defaults to AppConstants.TASK_SUBMIT_MODEL. + validate_task (str, optional): task name for validation. Defaults to AppConstants.TASK_VALIDATION. """ super().__init__() self.learner_id = learner_id @@ -46,18 +46,23 @@ def __init__( self.train_task = train_task self.submit_model_task = submit_model_task self.validate_task = validate_task + self.is_initialized = False def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self.initialize(fl_ctx) - elif event_type == EventType.ABORT_TASK: + if event_type == EventType.ABORT_TASK: try: if self.learner: - self.learner.abort(fl_ctx) + if not self.unsafe: + self.learner.abort(fl_ctx) + else: + self.log_warning(fl_ctx, f"skipped abort of unsafe learner {self.learner.__class__.__name__}") except Exception as e: self.log_exception(fl_ctx, f"learner abort exception: {secure_format_exception(e)}") elif event_type == EventType.END_RUN: - self.finalize(fl_ctx) + if not self.unsafe: + self.finalize(fl_ctx) + elif self.learner: + self.log_warning(fl_ctx, f"skipped finalize of unsafe learner {self.learner.__class__.__name__}") def initialize(self, fl_ctx: FLContext): try: @@ -66,26 +71,25 @@ def initialize(self, fl_ctx: FLContext): if not isinstance(self.learner, Learner): raise TypeError(f"learner must be Learner type. Got: {type(self.learner)}") self.learner.initialize(engine.get_all_components(), fl_ctx) - except Exception as e: + except BaseException as e: self.log_exception(fl_ctx, f"learner initialize exception: {secure_format_exception(e)}") + raise e def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: self.log_info(fl_ctx, f"Client trainer got task: {task_name}") + if not self.is_initialized: + self.is_initialized = True + self.initialize(fl_ctx) - try: - if task_name == self.train_task: - return self.train(shareable, fl_ctx, abort_signal) - elif task_name == self.submit_model_task: - return self.submit_model(shareable, fl_ctx) - elif task_name == self.validate_task: - return self.validate(shareable, fl_ctx, abort_signal) - else: - self.log_error(fl_ctx, f"Could not handle task: {task_name}") - return make_reply(ReturnCode.TASK_UNKNOWN) - except Exception as e: - # Task execution error, return EXECUTION_EXCEPTION Shareable - self.log_exception(fl_ctx, f"learner execute exception: {secure_format_exception(e)}") - return make_reply(ReturnCode.EXECUTION_EXCEPTION) + if task_name == self.train_task: + return self.train(shareable, fl_ctx, abort_signal) + elif task_name == self.submit_model_task: + return self.submit_model(shareable, fl_ctx) + elif task_name == self.validate_task: + return self.validate(shareable, fl_ctx, abort_signal) + else: + self.log_error(fl_ctx, f"Could not handle task: {task_name}") + return make_reply(ReturnCode.TASK_UNKNOWN) def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: self.log_debug(fl_ctx, f"train abort signal: {abort_signal.triggered}") diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index 86ee19c83e..9d43d8fb1b 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -19,6 +19,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ReturnCode from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_exception import UnsafeJobError from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.apis.utils.fl_context_utils import add_job_audit_event @@ -113,7 +114,8 @@ def _register_aux_message_handler(self, engine): engine.register_aux_message_handler(topic=ReservedTopic.END_RUN, message_handle_func=self._handle_end_run) engine.register_aux_message_handler(topic=ReservedTopic.ABORT_ASK, message_handle_func=self._handle_abort_task) - def _reply_and_audit(self, reply: Shareable, ref, msg, fl_ctx: FLContext) -> Shareable: + @staticmethod + def _reply_and_audit(reply: Shareable, ref, msg, fl_ctx: FLContext) -> Shareable: audit_event_id = add_job_audit_event(fl_ctx=fl_ctx, ref=ref, msg=msg) reply.set_header(ReservedKey.AUDIT_EVENT_ID, audit_event_id) return reply @@ -167,7 +169,7 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: executor = self.task_table.get(task.name) if not executor: - self.log_error(fl_ctx, "bad task assignment: no executor available for task {}".format(task.name)) + self.log_error(fl_ctx, f"bad task assignment: no executor available for task {task.name}") return self._reply_and_audit( reply=make_reply(ReturnCode.TASK_UNKNOWN), ref=server_audit_event_id, @@ -175,6 +177,8 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: msg=f"submit result: {ReturnCode.TASK_UNKNOWN}", ) + executor_name = executor.__class__.__name__ + self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER") self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx) @@ -193,10 +197,23 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: if filter_list: task_data = task.data for f in filter_list: + filter_name = f.__class__.__name__ try: task_data = f.process(task_data, fl_ctx) - except Exception: - self.log_exception(fl_ctx, "processing error in Task Data Filter {}".format(type(f))) + except UnsafeJobError: + self.log_exception(fl_ctx, f"UnsafeJobError from Task Data Filter {filter_name}") + executor.unsafe = True + fl_ctx.set_job_is_unsafe() + return self._reply_and_audit( + reply=make_reply(ReturnCode.UNSAFE_JOB), + ref=server_audit_event_id, + fl_ctx=fl_ctx, + msg=f"submit result: {ReturnCode.UNSAFE_JOB}", + ) + except Exception as e: + self.log_exception( + fl_ctx, f"Processing error from Task Data Filter {filter_name}: {secure_format_exception(e)}" + ) return self._reply_and_audit( reply=make_reply(ReturnCode.TASK_DATA_FILTER_ERROR), ref=server_audit_event_id, @@ -225,8 +242,8 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False) self.fire_event(EventType.BEFORE_TASK_EXECUTION, fl_ctx) try: - self.log_info(fl_ctx, "invoking task executor {}".format(type(executor))) - add_job_audit_event(fl_ctx=fl_ctx, msg=f"invoked executor {type(executor)}") + self.log_info(fl_ctx, f"invoking task executor {executor_name}") + add_job_audit_event(fl_ctx=fl_ctx, msg=f"invoked executor {executor_name}") with self.task_lock: self.task_abort_signal = Signal() @@ -255,10 +272,7 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: if not isinstance(reply, Shareable): self.log_error( - fl_ctx, - "bad result generated by executor {}: must be Shareable but got {}".format( - type(executor), type(reply) - ), + fl_ctx, f"bad result generated by executor {executor_name}: must be Shareable but got {type(reply)}" ) return self._reply_and_audit( reply=make_reply(ReturnCode.EXECUTION_RESULT_ERROR), @@ -269,7 +283,7 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: except RuntimeError as e: self.log_exception( - fl_ctx, f"Critical RuntimeError happened with Exception {secure_format_exception(e)}: Aborting the RUN!" + fl_ctx, f"RuntimeError from executor {executor_name}: {secure_format_exception(e)}: Aborting the job!" ) self.asked_to_stop = True return self._reply_and_audit( @@ -278,8 +292,18 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: fl_ctx=fl_ctx, msg=f"submit result: {ReturnCode.EXECUTION_RESULT_ERROR}", ) - except Exception: - self.log_exception(fl_ctx, "processing error in task executor {}".format(type(executor))) + except UnsafeJobError: + self.log_exception(fl_ctx, f"UnsafeJobError from executor {executor_name}") + executor.unsafe = True + fl_ctx.set_job_is_unsafe() + return self._reply_and_audit( + reply=make_reply(ReturnCode.UNSAFE_JOB), + ref=server_audit_event_id, + fl_ctx=fl_ctx, + msg=f"submit result: {ReturnCode.UNSAFE_JOB}", + ) + except Exception as e: + self.log_exception(fl_ctx, f"Processing error from executor {executor_name}: {secure_format_exception(e)}") return self._reply_and_audit( reply=make_reply(ReturnCode.EXECUTION_EXCEPTION), ref=server_audit_event_id, @@ -305,10 +329,23 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: if filter_list: for f in filter_list: + filter_name = f.__class__.__name__ try: reply = f.process(reply, fl_ctx) - except Exception: - self.log_exception(fl_ctx, "processing error in Task Result Filter {}".format(type(f))) + except UnsafeJobError: + self.log_exception(fl_ctx, f"UnsafeJobError from Task Result Filter {filter_name}") + executor.unsafe = True + fl_ctx.set_job_is_unsafe() + return self._reply_and_audit( + reply=make_reply(ReturnCode.UNSAFE_JOB), + ref=server_audit_event_id, + fl_ctx=fl_ctx, + msg=f"submit result: {ReturnCode.UNSAFE_JOB}", + ) + except Exception as e: + self.log_exception( + fl_ctx, f"Processing error in Task Result Filter {filter_name}: {secure_format_exception(e)}" + ) return self._reply_and_audit( reply=make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR), ref=server_audit_event_id, @@ -346,17 +383,24 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable: return self._reply_and_audit(reply=reply, ref=server_audit_event_id, fl_ctx=fl_ctx, msg="submit result OK") + def _check_stop_conditions(self, fl_ctx: FLContext) -> bool: + if fl_ctx.is_job_unsafe(): + self.log_info(fl_ctx, "stopped unsafe job!") + return True + if self.run_abort_signal.triggered: + self.log_info(fl_ctx, "run abort signal received") + return True + return False + def _try_run(self): while not self.asked_to_stop: with self.engine.new_context() as fl_ctx: - if self.run_abort_signal.triggered: - self.log_info(fl_ctx, "run abort signal received") + if self._check_stop_conditions(fl_ctx): break task_fetch_interval, _ = self.fetch_and_run_one_task(fl_ctx) - if self.run_abort_signal.triggered: - self.log_info(fl_ctx, "run abort signal received") + if self._check_stop_conditions(fl_ctx): break time.sleep(task_fetch_interval) @@ -431,7 +475,6 @@ def run(self, app_root, args): finally: # in case any task is still running, abort it self._abort_current_task() - self.end_run_events_sequence("run method") def init_run(self, app_root, args):