From f208127551f257c0a6a430873b86675f22ea1d72 Mon Sep 17 00:00:00 2001 From: Yan Cheng <58191769+yanchengnv@users.noreply.github.com> Date: Fri, 13 Oct 2023 13:48:04 -0400 Subject: [PATCH] Support 3rd-party training system integration with FLARE (#2074) * support av ipc based model exchange * polish * add license text * support child-based comm * reformat * support client side listening * reorg * formatting * added license text * fix f-str * address PR comments --- integration/av/__init__.py | 13 + integration/av/trainer.py | 99 ++++ nvflare/app_common/executors/ipc_exchanger.py | 406 ++++++++++++++++ nvflare/client/__init__.py | 13 + nvflare/client/defs.py | 149 ++++++ nvflare/client/ipc_agent.py | 446 ++++++++++++++++++ nvflare/fuel/f3/cellnet/cell.py | 4 + nvflare/fuel/f3/cellnet/connector_manager.py | 7 +- nvflare/fuel/f3/drivers/net_utils.py | 45 ++ nvflare/fuel/f3/sfm/conn_manager.py | 12 +- 10 files changed, 1190 insertions(+), 4 deletions(-) create mode 100644 integration/av/__init__.py create mode 100644 integration/av/trainer.py create mode 100644 nvflare/app_common/executors/ipc_exchanger.py create mode 100644 nvflare/client/__init__.py create mode 100644 nvflare/client/defs.py create mode 100644 nvflare/client/ipc_agent.py diff --git a/integration/av/__init__.py b/integration/av/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/integration/av/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/integration/av/trainer.py b/integration/av/trainer.py new file mode 100644 index 0000000000..40135f4329 --- /dev/null +++ b/integration/av/trainer.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +from nvflare.client.defs import RC, AgentClosed, MetaKey, TaskResult +from nvflare.client.ipc_agent import IPCAgent + +NUMPY_KEY = "numpy_key" + + +def main(): + + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", "-w", type=str, help="workspace folder", required=False, default=".") + parser.add_argument("--site_name", "-s", type=str, help="flare site name", required=True) + parser.add_argument("--agent_id", "-a", type=str, help="agent id", required=True) + parser.add_argument("--job_id", "-j", type=str, help="flare job id", required=False, default="") + parser.add_argument("--site_url", "-u", type=str, help="flare site url", required=False, default="") + + args = parser.parse_args() + + agent = IPCAgent( + root_url="grpc://server:8002", + flare_site_name=args.site_name, + agent_id=args.agent_id, + workspace_dir=args.workspace, + secure_mode=True, + submit_result_timeout=2.0, + flare_site_heartbeat_timeout=120.0, + job_id=args.job_id, + flare_site_url=args.site_url, + ) + + agent.start() + + while True: + print("getting task ...") + try: + task = agent.get_task() + except AgentClosed: + print("agent closed - exit") + break + + print(f"got task: {task}") + rc, meta, result = train(task.meta, task.data) + submitted = agent.submit_result(TaskResult(data=result, meta=meta, return_code=rc)) + print(f"result submitted: {submitted}") + + agent.stop() + + +def train(meta, model): + current_round = meta.get(MetaKey.CURRENT_ROUND) + total_rounds = meta.get(MetaKey.TOTAL_ROUND) + + # Ensure that data is of type weights. Extract model data + np_data = model + + # Display properties. + print(f"Model: \n{np_data}") + print(f"Current Round: {current_round}") + print(f"Total Rounds: {total_rounds}") + + # Doing some dummy training. + if np_data: + if NUMPY_KEY in np_data: + np_data[NUMPY_KEY] += 1.0 + else: + print("error: numpy_key not found in model.") + return RC.BAD_TASK_DATA, None, None + else: + print("No model weights found in shareable.") + return RC.BAD_TASK_DATA, None, None + + # Save local numpy model + print(f"Model after training: {np_data}") + + # Prepare a DXO for our updated model. Create shareable and return + return RC.OK, {MetaKey.NUM_STEPS_CURRENT_ROUND: 1}, np_data + + +if __name__ == "__main__": + main() diff --git a/nvflare/app_common/executors/ipc_exchanger.py b/nvflare/app_common/executors/ipc_exchanger.py new file mode 100644 index 0000000000..f483c6d6c4 --- /dev/null +++ b/nvflare/app_common/executors/ipc_exchanger.py @@ -0,0 +1,406 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from typing import Union + +from nvflare.apis.dxo import DXO, DataKind, from_shareable +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.app_constant import AppConstants +from nvflare.client import defs +from nvflare.fuel.f3.cellnet.cell import Cell, Message, MessageHeaderKey +from nvflare.fuel.f3.cellnet.cell import ReturnCode as CellReturnCode +from nvflare.fuel.f3.cellnet.cell import new_message +from nvflare.fuel.f3.cellnet.utils import make_reply as make_cell_reply + + +class _TaskContext: + def __init__(self, task_name: str, task_id: str, fl_ctx: FLContext): + self.task_id = task_id + self.task_name = task_name + self.fl_ctx = fl_ctx + self.send_rc = None + self.result_rc = None + self.result_error = None + self.result = None + self.result_received_time = None + self.result_waiter = threading.Event() + + def __str__(self): + return f"'{self.task_name} {self.task_id}'" + + +class IPCExchanger(Executor): + def __init__( + self, + send_task_timeout=5.0, + agent_ready_timeout=60.0, + agent_heartbeat_timeout=600.0, + agent_is_child=False, + ): + """Constructor of IPCExchanger + + Args: + send_task_timeout: when sending task to Agent, how long to wait for response + agent_ready_timeout: how long to wait for the agent to be connected + agent_heartbeat_timeout: max time allowed to miss heartbeats from the agent + agent_is_child: whether the agent will be a child cell. + """ + Executor.__init__(self) + self.flare_agent_fqcn = None + self.agent_ready_waiter = threading.Event() + self.agent_ready_timeout = agent_ready_timeout + self.agent_heartbeat_timeout = agent_heartbeat_timeout + self.send_task_timeout = send_task_timeout + self.agent_is_child = agent_is_child + self.internal_listener_url = None + self.last_agent_ack_time = time.time() + self.engine = None + self.cell = None + self.is_done = False + self.task_ctx = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.engine = fl_ctx.get_engine() + self.cell = self.engine.get_cell() + + self.cell.register_request_cb( + channel=defs.CHANNEL, + topic=defs.TOPIC_SUBMIT_RESULT, + cb=self._receive_result, + ) + + # get meta + meta = fl_ctx.get_prop(FLContextKey.JOB_META) + assert isinstance(meta, dict) + agent_id = meta.get(defs.JOB_META_KEY_AGENT_ID) + if not agent_id: + self.system_panic(reason=f"missing {defs.JOB_META_KEY_AGENT_ID} from job meta", fl_ctx=fl_ctx) + return + + client_name = fl_ctx.get_identity_name() + self.flare_agent_fqcn = defs.agent_site_fqcn(client_name, agent_id) + + if self.agent_is_child: + job_id = fl_ctx.get_job_id() + self.flare_agent_fqcn = defs.agent_site_fqcn(client_name, agent_id, job_id) + self.cell.make_internal_listener() + self.internal_listener_url = self.cell.get_internal_listener_url() + self.logger.info(f"URL for Agent: {self.internal_listener_url}") + + self.log_info(fl_ctx, f"Flare Agent FQCN: {self.flare_agent_fqcn}") + t = threading.Thread(target=self._maintain, daemon=True) + t.start() + elif event_type == EventType.END_RUN: + self.is_done = True + self._say_goodbye() + + def _say_goodbye(self): + # say goodbye to agent + self.logger.info(f"job done - say goodbye to {self.flare_agent_fqcn}") + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_BYE, + target=self.flare_agent_fqcn, + request=new_message(), + optional=True, + timeout=2.0, + ) + if reply: + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc != CellReturnCode.OK: + self.logger.warning(f"return code from agent {self.flare_agent_fqcn} for bye: {rc}") + + def _maintain(self): + # try to connect the flare agent + self.logger.info(f"waiting for flare agent {self.flare_agent_fqcn} ...") + assert isinstance(self.cell, Cell) + start_time = time.time() + while not self.is_done: + self.logger.info(f"ping {self.flare_agent_fqcn}") + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_HELLO, + target=self.flare_agent_fqcn, + request=new_message(), + timeout=2.0, + optional=True, + ) + + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == CellReturnCode.OK: + self.logger.info(f"connected to agent {self.flare_agent_fqcn}") + self.agent_ready_waiter.set() + break + + self.logger.info(f"get reply: {reply.headers}") + if time.time() - start_time > self.agent_ready_timeout: + # cannot connect to agent! + with self.engine.new_context() as fl_ctx: + self.system_panic( + reason=f"cannot connect to agent {self.flare_agent_fqcn} after {self.agent_ready_timeout} secs", + fl_ctx=fl_ctx, + ) + self.is_done = True + return + time.sleep(2.0) + + # agent is now connected - heartbeats + last_hb_time = 0 + hb_interval = 10.0 + while True: + if self.is_done: + return + + if time.time() - last_hb_time > hb_interval: + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_HEARTBEAT, + target=self.flare_agent_fqcn, + request=new_message(), + timeout=1.5, + ) + last_hb_time = time.time() + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == CellReturnCode.OK: + self.last_agent_ack_time = time.time() + + if time.time() - self.last_agent_ack_time > self.agent_heartbeat_timeout: + with self.engine.new_context() as fl_ctx: + self.system_panic( + reason=f"agent dead: no heartbeat for {self.agent_heartbeat_timeout} secs", + fl_ctx=fl_ctx, + ) + self.is_done = True + return + + # sleep only small amount of time, so we can check other conditions frequently + time.sleep(0.2) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + # wait for flare agent + while True: + if self.is_done or abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # wait for agent to be ready + # we only wait for short time, so we could check other conditions (is_done, abort_signal) + if self.agent_ready_waiter.wait(0.5): + break + + task_id = shareable.get_header(key=FLContextKey.TASK_ID) + current_task = self.task_ctx + if current_task: + # still working on previous task! + self.log_error(fl_ctx, f"got new task {task_name=} {task_id=} while still working on {current_task}") + return make_reply(ReturnCode.BAD_REQUEST_DATA) + + self.task_ctx = _TaskContext(task_name, task_id, fl_ctx) + result = self._do_execute(task_name, shareable, fl_ctx, abort_signal) + self.task_ctx = None + return result + + def _send_task(self, task_ctx: _TaskContext, msg, abort_signal): + # keep sending until done + fl_ctx = task_ctx.fl_ctx + task_name = task_ctx.task_name + task_id = task_ctx.task_id + task_ctx.send_rc = ReturnCode.OK + while True: + if self.is_done or abort_signal.triggered: + self.log_info(fl_ctx, "task aborted - ask agent to abort the task") + + # it's possible that the agent may have already received the task + # we ask it to abort the task. + self._ask_agent_to_abort_task(task_name, task_id) + task_ctx.send_rc = ReturnCode.TASK_ABORTED + return + + if task_ctx.result_received_time: + # the result has been received + # this could happen only when we thought the previous send didn't succeed, but it actually did! + return + + self.log_info(fl_ctx, f"try to send task to {self.flare_agent_fqcn}") + start = time.time() + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_GET_TASK, + request=msg, + target=self.flare_agent_fqcn, + timeout=self.send_task_timeout, + ) + + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + if rc == CellReturnCode.OK: + self.log_info(fl_ctx, f"Sent task to {self.flare_agent_fqcn} in {time.time() - start} secs") + return + elif rc == CellReturnCode.INVALID_REQUEST: + self.log_error(fl_ctx, f"Task rejected by {self.flare_agent_fqcn}: {rc}") + task_ctx.send_rc = ReturnCode.BAD_REQUEST_DATA + return + else: + self.log_error(fl_ctx, f"Failed to send task to {self.flare_agent_fqcn}: {rc}. Will keep trying.") + time.sleep(2.0) + + def _do_execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, "Unable to extract dxo from shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Ensure data kind is weights. + if dxo.data_kind != DataKind.WEIGHTS: + self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # send to flare agent + task_ctx = self.task_ctx + task_id = task_ctx.task_id + data = dxo.data + if not data: + data = {} + meta = dxo.meta + if not meta: + meta = {} + + current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) + total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) + + meta[defs.MetaKey.DATA_KIND] = dxo.data_kind + if current_round is not None: + meta[defs.MetaKey.CURRENT_ROUND] = current_round + if total_rounds is not None: + meta[defs.MetaKey.TOTAL_ROUND] = total_rounds + + msg = new_message( + headers={ + defs.MsgHeader.TASK_ID: task_id, + defs.MsgHeader.TASK_NAME: task_name, + }, + payload={defs.PayloadKey.DATA: data, defs.PayloadKey.META: meta}, + ) + + # keep sending until done + self._send_task(task_ctx, msg, abort_signal) + if task_ctx.send_rc != ReturnCode.OK: + # send_task failed + return make_reply(task_ctx.send_rc) + + # wait for result + self.log_info(fl_ctx, f"Waiting for result from {self.flare_agent_fqcn}") + waiter_timeout = 0.5 + while True: + if task_ctx.result_waiter.wait(timeout=waiter_timeout): + # result available + break + else: + # timed out - check other conditions + if self.is_done or abort_signal.triggered: + self.log_info(fl_ctx, "task is aborted") + + # notify the agent + self._ask_agent_to_abort_task(task_name, task_id) + self.task_ctx = None + return make_reply(ReturnCode.TASK_ABORTED) + + # convert the result + if task_ctx.result_rc != defs.RC.OK: + return make_reply(task_ctx.result_rc) + + result = task_ctx.result + meta = result.get(defs.PayloadKey.META) + data_kind = meta.get(defs.MetaKey.DATA_KIND, DataKind.WEIGHTS) + dxo = DXO( + data_kind=data_kind, + data=result.get(defs.PayloadKey.DATA), + meta=meta, + ) + return dxo.to_shareable() + + def _ask_agent_to_abort_task(self, task_name, task_id): + msg = new_message( + headers={ + defs.MsgHeader.TASK_ID: task_id, + defs.MsgHeader.TASK_NAME: task_name, + } + ) + + self.cell.fire_and_forget( + channel=defs.CHANNEL, + topic=defs.TOPIC_ABORT, + message=msg, + targets=[self.flare_agent_fqcn], + optional=True, + ) + + @staticmethod + def _finish_result(task_ctx: _TaskContext, result_rc="", result=None, result_is_valid=True): + task_ctx.result_rc = result_rc + task_ctx.result = result + task_ctx.result_received_time = time.time() + task_ctx.result_waiter.set() + if result_is_valid: + return make_cell_reply(CellReturnCode.OK) + else: + return make_cell_reply(CellReturnCode.INVALID_REQUEST) + + def _receive_result(self, request: Message) -> Union[None, Message]: + sender = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_ctx = self.task_ctx + if not task_ctx: + self.logger.error(f"received result from {sender} for task {task_id} while not waiting for result!") + return make_cell_reply(CellReturnCode.INVALID_REQUEST) + + fl_ctx = task_ctx.fl_ctx + if task_id != task_ctx.task_id: + self.log_error(fl_ctx, f"received task id {task_id} != expected {task_ctx.task_id}") + return make_cell_reply(CellReturnCode.INVALID_REQUEST) + + if task_ctx.result_received_time: + # already received - this is a dup + self.log_info(fl_ctx, f"received duplicate result from {sender}") + return make_cell_reply(CellReturnCode.OK) + + payload = request.payload + if not isinstance(payload, dict): + self.log_error(fl_ctx, f"bad result from {sender}: expect dict but got {type(payload)}") + return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION) + + data = payload.get(defs.PayloadKey.DATA) + if data is None: + self.log_error(fl_ctx, f"bad result from {sender}: missing {defs.PayloadKey.DATA}") + return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION) + + meta = payload.get(defs.PayloadKey.META) + if meta is None: + self.log_error(fl_ctx, f"bad result from {sender}: missing {defs.PayloadKey.META}") + return self._finish_result(task_ctx, result_is_valid=False, result_rc=ReturnCode.EXECUTION_EXCEPTION) + + self.log_info(fl_ctx, f"received result from {sender}") + return self._finish_result( + task_ctx, + result_is_valid=True, + result_rc=request.get_header(defs.MsgHeader.RC, defs.RC.OK), + result=payload, + ) diff --git a/nvflare/client/__init__.py b/nvflare/client/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/client/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/client/defs.py b/nvflare/client/defs.py new file mode 100644 index 0000000000..e7fa8c146d --- /dev/null +++ b/nvflare/client/defs.py @@ -0,0 +1,149 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from nvflare.fuel.f3.cellnet.fqcn import FQCN + +CHANNEL = "flare_agent" + +TOPIC_GET_TASK = "get_task" +TOPIC_SUBMIT_RESULT = "submit_result" +TOPIC_HEARTBEAT = "heartbeat" +TOPIC_HELLO = "hello" +TOPIC_BYE = "bye" +TOPIC_ABORT = "abort" + +JOB_META_KEY_AGENT_ID = "agent_id" + + +class RC: + OK = "OK" + BAD_TASK_DATA = "BAD_TASK_DATA" + EXECUTION_EXCEPTION = "EXECUTION_EXCEPTION" + + +class MsgHeader: + + TASK_ID = "task_id" + TASK_NAME = "task_name" + RC = "rc" + + +class PayloadKey: + DATA = "data" + META = "meta" + + +class MetaKey: + CURRENT_ROUND = "current_round" + TOTAL_ROUND = "total_round" + DATA_KIND = "data_kind" + NUM_STEPS_CURRENT_ROUND = "NUM_STEPS_CURRENT_ROUND" + PROCESSED_ALGORITHM = "PROCESSED_ALGORITHM" + PROCESSED_KEYS = "PROCESSED_KEYS" + INITIAL_METRICS = "initial_metrics" + FILTER_HISTORY = "filter_history" + + +class Task: + def __init__(self, task_name: str, task_id: str, meta: dict, data): + self.task_name = task_name + self.task_id = task_id + self.meta = meta + self.data = data + + def __str__(self): + return f"'{self.task_name} {self.task_id}'" + + +class TaskResult: + def __init__(self, meta: dict, data, return_code=RC.OK): + if not meta: + meta = {} + + if not isinstance(meta, dict): + raise TypeError(f"meta must be dict but got {type(meta)}") + + if not data: + data = {} + + if not isinstance(return_code, str): + raise TypeError(f"return_code must be str but got {type(return_code)}") + + self.return_code = return_code + self.meta = meta + self.data = data + + +class AgentClosed(Exception): + pass + + +class CallStateError(Exception): + pass + + +class FlareAgent(ABC): + @abstractmethod + def start(self): + pass + + @abstractmethod + def stop(self): + pass + + @abstractmethod + def get_task(self, timeout=None): + """Get a task from FLARE. This is a blocking call. + + If timeout is specified, this call is blocked only for the specified amount of time. + If timeout is not specified, this call is blocked forever until a task is received or agent is closed. + + Args: + timeout: amount of time to block + + Returns: None if no task is available during before timeout; or a Task object if task is available. + Raises: + AgentClosed exception if the agent is closed before timeout. + CallStateError exception if the call is not made properly. + + Note: the application must make the call only when it is just started or after a previous task's result + has been submitted. + + """ + pass + + def submit_result(self, result: TaskResult) -> bool: + """Submit the result of the current task. + This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or + the task is aborted or the agent is closed. + + Args: + result: result to be submitted + + Returns: whether the result is submitted successfully + Raises: the CallStateError exception if the submit_result call is not made properly. + + Notes: the application must only make this call after the received task is processed. The call can only be + made a single time regardless whether the submission is successful. + + """ + pass + + +def agent_site_fqcn(site_name: str, agent_id: str, job_id=None): + if not job_id: + return f"{site_name}--{agent_id}" + else: + return FQCN.join([site_name, job_id, agent_id]) diff --git a/nvflare/client/ipc_agent.py b/nvflare/client/ipc_agent.py new file mode 100644 index 0000000000..babaa82485 --- /dev/null +++ b/nvflare/client/ipc_agent.py @@ -0,0 +1,446 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time +import traceback +from typing import Union + +from nvflare.app_common.decomposers import common_decomposers +from nvflare.client import defs +from nvflare.fuel.f3.cellnet.cell import Cell, Message +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.net_agent import NetAgent +from nvflare.fuel.f3.cellnet.utils import make_reply, new_message +from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.utils.config_service import ConfigService + +SSL_ROOT_CERT = "rootCA.pem" + + +class _TaskContext: + + NEW = 0 + FETCHED = 1 + PROCESSED = 2 + + def __init__(self, sender: str, task_name: str, task_id: str, meta: dict, data): + self.sender = sender + self.task_name = task_name + self.task_id = task_id + self.meta = meta + self.data = data + self.status = _TaskContext.NEW + self.last_send_result_time = None + self.aborted = False + self.already_received = False + + def __str__(self): + return f"'{self.task_name} {self.task_id}'" + + +class IPCAgent(defs.FlareAgent): + def __init__( + self, + root_url: str, + flare_site_name: str, + agent_id: str, + workspace_dir: str, + secure_mode=False, + submit_result_timeout=30.0, + flare_site_heartbeat_timeout=60.0, + job_id=None, + flare_site_url=None, + ): + """Constructor of Flare Agent. The agent is responsible for communicating with the Flare Client Job cell (CJ) + to get task and to submit task result. + + Args: + root_url: the URL to the server parent cell (SP) + flare_site_name: the CJ's site name (client name) + agent_id: the unique ID of the agent + workspace_dir: directory that contains startup folder and comm_config.json + secure_mode: whether the connection is in secure mode or not + submit_result_timeout: when submitting task result, how long to wait for response from the CJ + flare_site_heartbeat_timeout: max time allowed for missing heartbeats from CJ + job_id: ID of the current Flare Job. Only needed for child-based communication with CJ + flare_site_url: URL for connection to CJ. Only needed for child-based communication with CJ + """ + ConfigService.initialize(section_files={}, config_path=[workspace_dir]) + + self.logger = logging.getLogger(self.__class__.__name__) + self.cell_name = defs.agent_site_fqcn(flare_site_name, agent_id, job_id) + self.workspace_dir = workspace_dir + self.secure_mode = secure_mode + self.root_url = root_url + self.submit_result_timeout = submit_result_timeout + self.flare_site_heartbeat_timeout = flare_site_heartbeat_timeout + self.job_id = job_id + self.flare_site_url = flare_site_url + self.connect_waiter = threading.Event() + self.current_task = None + self.pending_task = None + self.task_lock = threading.Lock() + self.last_hb_time = time.time() + self.is_done = False + self.is_started = False + self.is_stopped = False + self.credentials = {} + + if secure_mode: + root_cert_path = ConfigService.find_file(SSL_ROOT_CERT) + if not root_cert_path: + raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {workspace_dir}") + + self.credentials = { + DriverParams.CA_CERT.value: root_cert_path, + } + + self.cell = Cell( + fqcn=self.cell_name, + root_url=self.root_url, + secure=self.secure_mode, + credentials=self.credentials, + create_internal_listener=False, + parent_url=self.flare_site_url, + ) + self.net_agent = NetAgent(self.cell) + + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_GET_TASK, cb=self._receive_task) + self.logger.info(f"registered task CB for {defs.CHANNEL} {defs.TOPIC_GET_TASK}") + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HELLO, cb=self._handle_hello) + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HEARTBEAT, cb=self._handle_heartbeat) + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_BYE, cb=self._handle_bye) + self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_ABORT, cb=self._handle_abort_task) + common_decomposers.register() + + def start(self): + """Start the agent. This method must be called to enable CJ/Agent communication. + + Returns: None + + """ + if self.is_started: + self.logger.warning("the agent is already started") + return + + if self.is_stopped: + raise defs.CallStateError("cannot start the agent since it is already stopped") + + self.is_started = True + self.logger.info(f"starting agent {self.cell_name} ...") + self.cell.start() + t = threading.Thread(target=self._maintain, daemon=True) + t.start() + + def stop(self): + """Stop the agent. After this is called, there will be no more communications between CJ and agent. + + Returns: None + + """ + if not self.is_started: + self.logger.warning("cannot stop the agent since it is not started") + return + + if self.is_stopped: + self.logger.warning("agent is already stopped") + return + + self.is_stopped = True + self.cell.stop() + self.net_agent.close() + + def _maintain(self): + self.logger.info("waiting for flare site to connect ...") + start = time.time() + while True: + if self.connect_waiter.wait(0.5): + # connected! + break + else: + if self.is_done or self.is_stopped: + return + + if time.time() - start > self.flare_site_heartbeat_timeout: + self.logger.error( + f"closing agent {self.cell_name}: flare site not connected " + f"in {self.flare_site_heartbeat_timeout} seconds" + ) + self.is_done = True + return + + while True: + if time.time() - self.last_hb_time > self.flare_site_heartbeat_timeout: + self.logger.error( + f"closing agent {self.cell_name}: no heartbeat from flare site " + f"for {self.flare_site_heartbeat_timeout} seconds" + ) + self.is_done = True + return + time.sleep(1.0) + + def _handle_hello(self, request: Message) -> Union[None, Message]: + self.logger.info(f"got hello: {request.headers}") + sender = request.get_header(MessageHeaderKey.ORIGIN) + self.logger.info(f"connected to the flare site {sender}") + self.last_hb_time = time.time() + self.connect_waiter.set() + return make_reply(ReturnCode.OK) + + def _handle_bye(self, request: Message) -> Union[None, Message]: + sender = request.get_header(MessageHeaderKey.ORIGIN) + self.logger.info(f"got goodbye from {sender}") + self.is_done = True + return make_reply(ReturnCode.OK) + + def _handle_heartbeat(self, request: Message) -> Union[None, Message]: + self.last_hb_time = time.time() + sender = request.get_header(MessageHeaderKey.ORIGIN) + self.logger.info(f"got heartbeat from {sender}") + return make_reply(ReturnCode.OK) + + def _handle_abort_task(self, request: Message) -> Union[None, Message]: + sender = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_name = request.get_header(defs.MsgHeader.TASK_NAME) + self.logger.warning(f"received from {sender} to abort {task_name=} {task_id=}") + with self.task_lock: + if self.current_task and task_id == self.current_task.task_id: + self.current_task.aborted = True + elif self.pending_task and task_id == self.pending_task.task_id: + self.pending_task = None + return make_reply(ReturnCode.OK) + + def _receive_task(self, request: Message) -> Union[None, Message]: + self.logger.info("receiving task ...") + with self.task_lock: + return self._do_receive_task(request) + + def _create_task(self, request: Message): + sender = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_name = request.get_header(defs.MsgHeader.TASK_NAME) + + task_data = request.payload + if not isinstance(task_data, dict): + self.logger.error(f"bad task data from {sender}: expect dict but got {type(task_data)}") + return None + + data = task_data.get(defs.PayloadKey.DATA) + if not data: + self.logger.error(f"bad task data from {sender}: missing {defs.PayloadKey.DATA}") + return None + + meta = task_data.get(defs.PayloadKey.META) + if not meta: + self.logger.error(f"bad task data from {sender}: missing {defs.PayloadKey.META}") + return None + + return _TaskContext(sender, task_name, task_id, meta, data) + + def _do_receive_task(self, request: Message) -> Union[None, Message]: + sender = request.get_header(MessageHeaderKey.ORIGIN) + task_id = request.get_header(defs.MsgHeader.TASK_ID) + task_name = request.get_header(defs.MsgHeader.TASK_NAME) + self.logger.info(f"_do_receive_task from {sender}: {task_name=} {task_id=}") + + if self.pending_task: + if task_id == self.pending_task.task_id: + return make_reply(ReturnCode.OK) + else: + self.logger.error("got a new task while already have a pending task!") + return make_reply(ReturnCode.INVALID_REQUEST) + + current_task = self.current_task + if current_task: + assert isinstance(current_task, _TaskContext) + if task_id == current_task.task_id: + self.logger.info(f"received duplicate task {task_id} from {sender}") + return make_reply(ReturnCode.OK) + + if current_task.last_send_result_time: + # we already tried to send result back + # assume that the flare site has received + # we set the flag so the sending process will end quickly + # in the meanwhile we ask flare site to retry later + current_task.already_received = True + self.pending_task = self._create_task(request) + if self.pending_task: + return make_reply(ReturnCode.OK) + else: + return make_reply(ReturnCode.INVALID_REQUEST) + else: + # error - one task at a time + self.logger.error( + f"got task {task_name} {task_id} from {sender} " + f"while still working on {current_task.task_name} {current_task.task_id}" + ) + return make_reply(ReturnCode.INVALID_REQUEST) + + self.current_task = self._create_task(request) + if self.current_task: + return make_reply(ReturnCode.OK) + else: + return make_reply(ReturnCode.INVALID_REQUEST) + + def get_task(self, timeout=None): + """Get a task from FLARE. This is a blocking call. + + If timeout is specified, this call is blocked only for the specified amount of time. + If timeout is not specified, this call is blocked forever until a task is received or agent is closed. + + Args: + timeout: amount of time to block + + Returns: None if no task is available during before timeout; or a Task object if task is available. + Raises: + AgentClosed exception if the agent is closed before timeout. + CallStateError exception if the call is not made properly. + + Note: the application must make the call only when it is just started or after a previous task's result + has been submitted. + + """ + if timeout is not None: + if not isinstance(timeout, (int, float)): + raise TypeError(f"timeout must be (int, float) but got {type(timeout)}") + if timeout <= 0: + raise ValueError(f"timeout must > 0, but got {timeout}") + + start = time.time() + while True: + if self.is_done or self.is_stopped: + self.logger.info("no more tasks - agent closed") + raise defs.AgentClosed("flare agent is closed") + + with self.task_lock: + current_task = self.current_task + if current_task: + assert isinstance(current_task, _TaskContext) + if current_task.aborted: + pass + elif current_task.status == _TaskContext.NEW: + current_task.status = _TaskContext.FETCHED + return defs.Task( + current_task.task_name, current_task.task_id, current_task.meta, current_task.data + ) + else: + raise defs.CallStateError( + f"application called get_task while the current task is in status {current_task.status}" + ) + if timeout and time.time() - start > timeout: + # no task available before timeout + self.logger.info(f"get_task timeout after {timeout} seconds") + return None + time.sleep(0.5) + + def submit_result(self, result: defs.TaskResult) -> bool: + """Submit the result of the current task. + This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or + the task is aborted or the agent is closed. + + Args: + result: result to be submitted + + Returns: whether the result is submitted successfully + Raises: the CallStateError exception if the submit_result call is not made properly. + + Notes: the application must only make this call after the received task is processed. The call can only be + made a single time regardless whether the submission is successful. + + """ + if not isinstance(result, defs.TaskResult): + raise TypeError(f"result must be TaskResult but got {type(result)}") + + with self.task_lock: + current_task = self.current_task + if not current_task: + self.logger.error("submit_result is called but there is no current task!") + return False + + assert isinstance(current_task, _TaskContext) + if current_task.aborted: + return False + if current_task.status != _TaskContext.FETCHED: + raise defs.CallStateError( + f"submit_result is called while current task is in status {current_task.status}" + ) + current_task.status = _TaskContext.PROCESSED + try: + result = self._do_submit_result(current_task, result) + except: + self.logger.error(f"exception submitting result to {current_task.sender}") + traceback.print_exc() + result = False + + with self.task_lock: + self.current_task = None + if self.pending_task: + # a new task is waiting for the current task to finish + self.current_task = self.pending_task + self.pending_task = None + return result + + def _do_submit_result(self, current_task: _TaskContext, result: defs.TaskResult): + meta = result.meta + rc = result.return_code + data = result.data + + msg = new_message( + headers={ + defs.MsgHeader.TASK_NAME: current_task.task_name, + defs.MsgHeader.TASK_ID: current_task.task_id, + defs.MsgHeader.RC: rc, + }, + payload={ + defs.PayloadKey.META: meta, + defs.PayloadKey.DATA: data, + }, + ) + while True: + if current_task.already_received: + if not current_task.last_send_result_time: + self.logger.warning(f"task {current_task} was marked already_received but has been sent!") + return True + + if self.is_done or self.is_stopped: + self.logger.error(f"quit submitting result for task {current_task} since agent is closed") + return False + + if current_task.aborted: + self.logger.error(f"quit submitting result for task {current_task} since it is aborted") + return False + + current_task.last_send_result_time = time.time() + self.logger.info(f"sending result to {current_task.sender} for task {current_task}") + reply = self.cell.send_request( + channel=defs.CHANNEL, + topic=defs.TOPIC_SUBMIT_RESULT, + target=current_task.sender, + request=msg, + timeout=self.submit_result_timeout, + ) + if reply: + rc = reply.get_header(MessageHeaderKey.RETURN_CODE) + sender = reply.get_header(MessageHeaderKey.ORIGIN) + if rc == ReturnCode.OK: + return True + elif rc == ReturnCode.INVALID_REQUEST: + self.logger.error(f"received return code from {sender}: {rc}") + return False + else: + self.logger.info(f"failed to send to {current_task.sender}: {rc} - will retry") + time.sleep(2.0) diff --git a/nvflare/fuel/f3/cellnet/cell.py b/nvflare/fuel/f3/cellnet/cell.py index 404dea855d..94c1fe50d3 100644 --- a/nvflare/fuel/f3/cellnet/cell.py +++ b/nvflare/fuel/f3/cellnet/cell.py @@ -42,6 +42,7 @@ from nvflare.fuel.f3.communicator import Communicator, MessageReceiver from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState from nvflare.fuel.f3.message import Message from nvflare.fuel.f3.mpm import MainProcessMonitor @@ -363,6 +364,9 @@ def __init__( self.logger.debug(f"Creating Cell: {self.my_info.fqcn}") + if credentials: + enhance_credential_info(credentials) + ep = Endpoint( name=fqcn, conn_props=credentials, diff --git a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index 0431519efa..45171ddfb4 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -133,9 +133,10 @@ def is_adhoc_allowed(self, c1: FqcnInfo, c2: FqcnInfo) -> bool: return False # we only allow gen2 (or above) cells to directly connect - if c1.gen >= 2 and c2.gen >= 2: - return True - return False + # if c1.gen >= 2 and c2.gen >= 2: + # return True + # return False + return True @staticmethod def _validate_conn_config(config: dict, key: str) -> Union[None, dict]: diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 3f9ae1e77d..377dfd783d 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import os.path import random import socket import ssl @@ -38,6 +39,12 @@ MAX_HEADER_SIZE = 1024 * 1024 MAX_PAYLOAD_SIZE = MAX_FRAME_SIZE - 16 - MAX_HEADER_SIZE +SSL_SERVER_PRIVATE_KEY = "server.key" +SSL_SERVER_CERT = "server.crt" +SSL_CLIENT_PRIVATE_KEY = "client.key" +SSL_CLIENT_CERT = "client.crt" +SSL_ROOT_CERT = "rootCA.pem" + def ssl_required(params: dict) -> bool: """Check if SSL is required""" @@ -258,3 +265,41 @@ def get_tcp_urls(scheme: str, resources: dict) -> (str, str): connect_url = f"{scheme}://{host}:{port}" return connect_url, listening_url + + +def enhance_credential_info(params: dict): + # must have CA + ca_path = params.get(DriverParams.CA_CERT.value) + if not ca_path: + return params + + # assume all SSL credential files are in the same folder with CA cert + cred_folder = os.path.dirname(ca_path) + + client_cert_path = params.get(DriverParams.CLIENT_CERT.value) + if not client_cert_path: + # see whether the file client cert file exists + client_cert_path = os.path.join(cred_folder, SSL_CLIENT_CERT) + if os.path.exists(client_cert_path): + params[DriverParams.CLIENT_CERT.value] = client_cert_path + + client_key_path = params.get(DriverParams.CLIENT_KEY.value) + if not client_key_path: + # see whether the file client key file exists + client_key_path = os.path.join(cred_folder, SSL_CLIENT_PRIVATE_KEY) + if os.path.exists(client_key_path): + params[DriverParams.CLIENT_KEY.value] = client_key_path + + server_cert_path = params.get(DriverParams.SERVER_CERT.value) + if not server_cert_path: + # see whether the file client cert file exists + server_cert_path = os.path.join(cred_folder, SSL_SERVER_CERT) + if os.path.exists(server_cert_path): + params[DriverParams.SERVER_CERT.value] = server_cert_path + + server_key_path = params.get(DriverParams.SERVER_KEY.value) + if not server_key_path: + # see whether the file client key file exists + server_key_path = os.path.join(cred_folder, SSL_SERVER_PRIVATE_KEY) + if os.path.exists(server_key_path): + params[DriverParams.SERVER_KEY.value] = server_key_path diff --git a/nvflare/fuel/f3/sfm/conn_manager.py b/nvflare/fuel/f3/sfm/conn_manager.py index 4386cfc2b9..8f812e18a2 100644 --- a/nvflare/fuel/f3/sfm/conn_manager.py +++ b/nvflare/fuel/f3/sfm/conn_manager.py @@ -285,7 +285,17 @@ def start_connector_task(connector: ConnectorInfo): else: log.info(reconnect_msg) - time.sleep(wait) + # Do not sleep(wait) since 'wait' could be long and the connector.stopping could be set while sleep. + # To return as soon as connector.stopping is set, we need to check it frequently. + # We sleep very short period of time (0.1 sec). + wait_start = time.time() + while True: + if connector.stopping: + return + elif time.time() - wait_start >= wait: + break + time.sleep(0.1) + # Exponential backoff wait *= 2 if wait > MAX_WAIT: