Skip to content

Commit

Permalink
Protect learner against unsafe jobs (#1705) (#1822)
Browse files Browse the repository at this point in the history
* protect learner against unsafe jobs

* make unsafe job global property

Co-authored-by: Yan Cheng <[email protected]>
  • Loading branch information
YuanTingHsieh and yanchengnv authored Jun 22, 2023
1 parent ca112fb commit e330ec6
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 44 deletions.
4 changes: 4 additions & 0 deletions nvflare/apis/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class Executor(FLComponent, ABC):
"""

def __init__(self):
FLComponent.__init__(self)
self.unsafe = False

@abstractmethod
def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
"""Executes a task.
Expand Down
2 changes: 2 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ReturnCode(object):
MODEL_UNRECOGNIZED = "MODEL_UNRECOGNIZED"
VALIDATE_TYPE_UNKNOWN = "VALIDATE_TYPE_UNKNOWN"
EMPTY_RESULT = "EMPTY_RESULT"
UNSAFE_JOB = "UNSAFE_JOB"

SERVER_NOT_READY = "SERVER_NOT_READY"

Expand Down Expand Up @@ -99,6 +100,7 @@ class ReservedKey(object):
JOB_RUN_NUMBER = "__job_run_number__"
JOB_DEPLOY_DETAIL = "__job_deploy_detail__"
FATAL_SYSTEM_ERROR = "__fatal_system_error__"
JOB_IS_UNSAFE = "__job_is_unsafe__"


class FLContextKey(object):
Expand Down
6 changes: 6 additions & 0 deletions nvflare/apis/fl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ def get_job_id(self, default=None):
def get_identity_name(self, default=""):
return self._simple_get(ReservedKey.IDENTITY_NAME, default=default)

def set_job_is_unsafe(self, value: bool = True):
self.set_prop(ReservedKey.JOB_IS_UNSAFE, value, private=True, sticky=True)

def is_job_unsafe(self):
return self.get_prop(ReservedKey.JOB_IS_UNSAFE, False)

def get_run_abort_signal(self):
return self._simple_get(key=ReservedKey.RUN_ABORT_SIGNAL, default=None)

Expand Down
6 changes: 6 additions & 0 deletions nvflare/apis/fl_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ def __init__(self, message, exception=None):
if exception:
self.__dict__.update(exception.__dict__)
self.message = message


class UnsafeJobError(Exception):
"""Raised when a job is detected to be unsafe"""

pass
52 changes: 28 additions & 24 deletions nvflare/app_common/executors/learner_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,34 @@ def __init__(
"""Key component to run learner on clients.
Args:
learner_id (str): id pointing to the learner object
train_task (str, optional): label to dispatch train task. Defaults to AppConstants.TASK_TRAIN.
submit_model_task (str, optional): label to dispatch submit model task. Defaults to AppConstants.TASK_SUBMIT_MODEL.
validate_task (str, optional): label to dispatch validation task. Defaults to AppConstants.TASK_VALIDATION.
learner_id (str): id of the learner object
train_task (str, optional): task name for train. Defaults to AppConstants.TASK_TRAIN.
submit_model_task (str, optional): task name for submit model. Defaults to AppConstants.TASK_SUBMIT_MODEL.
validate_task (str, optional): task name for validation. Defaults to AppConstants.TASK_VALIDATION.
"""
super().__init__()
self.learner_id = learner_id
self.learner = None
self.train_task = train_task
self.submit_model_task = submit_model_task
self.validate_task = validate_task
self.is_initialized = False

def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.initialize(fl_ctx)
elif event_type == EventType.ABORT_TASK:
if event_type == EventType.ABORT_TASK:
try:
if self.learner:
self.learner.abort(fl_ctx)
if not self.unsafe:
self.learner.abort(fl_ctx)
else:
self.log_warning(fl_ctx, f"skipped abort of unsafe learner {self.learner.__class__.__name__}")
except Exception as e:
self.log_exception(fl_ctx, f"learner abort exception: {secure_format_exception(e)}")
elif event_type == EventType.END_RUN:
self.finalize(fl_ctx)
if not self.unsafe:
self.finalize(fl_ctx)
elif self.learner:
self.log_warning(fl_ctx, f"skipped finalize of unsafe learner {self.learner.__class__.__name__}")

def initialize(self, fl_ctx: FLContext):
try:
Expand All @@ -66,26 +71,25 @@ def initialize(self, fl_ctx: FLContext):
if not isinstance(self.learner, Learner):
raise TypeError(f"learner must be Learner type. Got: {type(self.learner)}")
self.learner.initialize(engine.get_all_components(), fl_ctx)
except Exception as e:
except BaseException as e:
self.log_exception(fl_ctx, f"learner initialize exception: {secure_format_exception(e)}")
raise e

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
self.log_info(fl_ctx, f"Client trainer got task: {task_name}")
if not self.is_initialized:
self.is_initialized = True
self.initialize(fl_ctx)

try:
if task_name == self.train_task:
return self.train(shareable, fl_ctx, abort_signal)
elif task_name == self.submit_model_task:
return self.submit_model(shareable, fl_ctx)
elif task_name == self.validate_task:
return self.validate(shareable, fl_ctx, abort_signal)
else:
self.log_error(fl_ctx, f"Could not handle task: {task_name}")
return make_reply(ReturnCode.TASK_UNKNOWN)
except Exception as e:
# Task execution error, return EXECUTION_EXCEPTION Shareable
self.log_exception(fl_ctx, f"learner execute exception: {secure_format_exception(e)}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
if task_name == self.train_task:
return self.train(shareable, fl_ctx, abort_signal)
elif task_name == self.submit_model_task:
return self.submit_model(shareable, fl_ctx)
elif task_name == self.validate_task:
return self.validate(shareable, fl_ctx, abort_signal)
else:
self.log_error(fl_ctx, f"Could not handle task: {task_name}")
return make_reply(ReturnCode.TASK_UNKNOWN)

def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
self.log_debug(fl_ctx, f"train abort signal: {abort_signal.triggered}")
Expand Down
83 changes: 63 additions & 20 deletions nvflare/private/fed/client/client_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_exception import UnsafeJobError
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.fl_context_utils import add_job_audit_event
Expand Down Expand Up @@ -113,7 +114,8 @@ def _register_aux_message_handler(self, engine):
engine.register_aux_message_handler(topic=ReservedTopic.END_RUN, message_handle_func=self._handle_end_run)
engine.register_aux_message_handler(topic=ReservedTopic.ABORT_ASK, message_handle_func=self._handle_abort_task)

def _reply_and_audit(self, reply: Shareable, ref, msg, fl_ctx: FLContext) -> Shareable:
@staticmethod
def _reply_and_audit(reply: Shareable, ref, msg, fl_ctx: FLContext) -> Shareable:
audit_event_id = add_job_audit_event(fl_ctx=fl_ctx, ref=ref, msg=msg)
reply.set_header(ReservedKey.AUDIT_EVENT_ID, audit_event_id)
return reply
Expand Down Expand Up @@ -167,14 +169,16 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:

executor = self.task_table.get(task.name)
if not executor:
self.log_error(fl_ctx, "bad task assignment: no executor available for task {}".format(task.name))
self.log_error(fl_ctx, f"bad task assignment: no executor available for task {task.name}")
return self._reply_and_audit(
reply=make_reply(ReturnCode.TASK_UNKNOWN),
ref=server_audit_event_id,
fl_ctx=fl_ctx,
msg=f"submit result: {ReturnCode.TASK_UNKNOWN}",
)

executor_name = executor.__class__.__name__

self.log_debug(fl_ctx, "firing event EventType.BEFORE_TASK_DATA_FILTER")
self.fire_event(EventType.BEFORE_TASK_DATA_FILTER, fl_ctx)

Expand All @@ -193,10 +197,23 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:
if filter_list:
task_data = task.data
for f in filter_list:
filter_name = f.__class__.__name__
try:
task_data = f.process(task_data, fl_ctx)
except Exception:
self.log_exception(fl_ctx, "processing error in Task Data Filter {}".format(type(f)))
except UnsafeJobError:
self.log_exception(fl_ctx, f"UnsafeJobError from Task Data Filter {filter_name}")
executor.unsafe = True
fl_ctx.set_job_is_unsafe()
return self._reply_and_audit(
reply=make_reply(ReturnCode.UNSAFE_JOB),
ref=server_audit_event_id,
fl_ctx=fl_ctx,
msg=f"submit result: {ReturnCode.UNSAFE_JOB}",
)
except Exception as e:
self.log_exception(
fl_ctx, f"Processing error from Task Data Filter {filter_name}: {secure_format_exception(e)}"
)
return self._reply_and_audit(
reply=make_reply(ReturnCode.TASK_DATA_FILTER_ERROR),
ref=server_audit_event_id,
Expand Down Expand Up @@ -225,8 +242,8 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:
fl_ctx.set_prop(FLContextKey.TASK_DATA, value=task.data, private=True, sticky=False)
self.fire_event(EventType.BEFORE_TASK_EXECUTION, fl_ctx)
try:
self.log_info(fl_ctx, "invoking task executor {}".format(type(executor)))
add_job_audit_event(fl_ctx=fl_ctx, msg=f"invoked executor {type(executor)}")
self.log_info(fl_ctx, f"invoking task executor {executor_name}")
add_job_audit_event(fl_ctx=fl_ctx, msg=f"invoked executor {executor_name}")

with self.task_lock:
self.task_abort_signal = Signal()
Expand Down Expand Up @@ -255,10 +272,7 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:

if not isinstance(reply, Shareable):
self.log_error(
fl_ctx,
"bad result generated by executor {}: must be Shareable but got {}".format(
type(executor), type(reply)
),
fl_ctx, f"bad result generated by executor {executor_name}: must be Shareable but got {type(reply)}"
)
return self._reply_and_audit(
reply=make_reply(ReturnCode.EXECUTION_RESULT_ERROR),
Expand All @@ -269,7 +283,7 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:

except RuntimeError as e:
self.log_exception(
fl_ctx, f"Critical RuntimeError happened with Exception {secure_format_exception(e)}: Aborting the RUN!"
fl_ctx, f"RuntimeError from executor {executor_name}: {secure_format_exception(e)}: Aborting the job!"
)
self.asked_to_stop = True
return self._reply_and_audit(
Expand All @@ -278,8 +292,18 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:
fl_ctx=fl_ctx,
msg=f"submit result: {ReturnCode.EXECUTION_RESULT_ERROR}",
)
except Exception:
self.log_exception(fl_ctx, "processing error in task executor {}".format(type(executor)))
except UnsafeJobError:
self.log_exception(fl_ctx, f"UnsafeJobError from executor {executor_name}")
executor.unsafe = True
fl_ctx.set_job_is_unsafe()
return self._reply_and_audit(
reply=make_reply(ReturnCode.UNSAFE_JOB),
ref=server_audit_event_id,
fl_ctx=fl_ctx,
msg=f"submit result: {ReturnCode.UNSAFE_JOB}",
)
except Exception as e:
self.log_exception(fl_ctx, f"Processing error from executor {executor_name}: {secure_format_exception(e)}")
return self._reply_and_audit(
reply=make_reply(ReturnCode.EXECUTION_EXCEPTION),
ref=server_audit_event_id,
Expand All @@ -305,10 +329,23 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:

if filter_list:
for f in filter_list:
filter_name = f.__class__.__name__
try:
reply = f.process(reply, fl_ctx)
except Exception:
self.log_exception(fl_ctx, "processing error in Task Result Filter {}".format(type(f)))
except UnsafeJobError:
self.log_exception(fl_ctx, f"UnsafeJobError from Task Result Filter {filter_name}")
executor.unsafe = True
fl_ctx.set_job_is_unsafe()
return self._reply_and_audit(
reply=make_reply(ReturnCode.UNSAFE_JOB),
ref=server_audit_event_id,
fl_ctx=fl_ctx,
msg=f"submit result: {ReturnCode.UNSAFE_JOB}",
)
except Exception as e:
self.log_exception(
fl_ctx, f"Processing error in Task Result Filter {filter_name}: {secure_format_exception(e)}"
)
return self._reply_and_audit(
reply=make_reply(ReturnCode.TASK_RESULT_FILTER_ERROR),
ref=server_audit_event_id,
Expand Down Expand Up @@ -346,17 +383,24 @@ def _process_task(self, task: TaskAssignment, fl_ctx: FLContext) -> Shareable:

return self._reply_and_audit(reply=reply, ref=server_audit_event_id, fl_ctx=fl_ctx, msg="submit result OK")

def _check_stop_conditions(self, fl_ctx: FLContext) -> bool:
if fl_ctx.is_job_unsafe():
self.log_info(fl_ctx, "stopped unsafe job!")
return True
if self.run_abort_signal.triggered:
self.log_info(fl_ctx, "run abort signal received")
return True
return False

def _try_run(self):
while not self.asked_to_stop:
with self.engine.new_context() as fl_ctx:
if self.run_abort_signal.triggered:
self.log_info(fl_ctx, "run abort signal received")
if self._check_stop_conditions(fl_ctx):
break

task_fetch_interval, _ = self.fetch_and_run_one_task(fl_ctx)

if self.run_abort_signal.triggered:
self.log_info(fl_ctx, "run abort signal received")
if self._check_stop_conditions(fl_ctx):
break

time.sleep(task_fetch_interval)
Expand Down Expand Up @@ -431,7 +475,6 @@ def run(self, app_root, args):
finally:
# in case any task is still running, abort it
self._abort_current_task()

self.end_run_events_sequence("run method")

def init_run(self, app_root, args):
Expand Down

0 comments on commit e330ec6

Please sign in to comment.