Add support of just doing metrics streaming with client api (#2763)
* Add support of just doing metrics streaming with client api

* Address review comments
YuanTingHsieh authored Aug 27, 2024
1 parent 7b01b0f commit e956fea
Showing 5 changed files with 66 additions and 35 deletions.
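For orientation, the capability this commit adds (streaming metrics from an external training process without any task exchange) could be exercised roughly as follows. This is a minimal sketch, not part of the commit: it assumes a job configuration that defines only a metrics pipe for the client, and that the SummaryWriter helper in nvflare.client.tracking behaves as in the NVFlare examples.

# Hypothetical metrics-only training script run under the external-process client API.
import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter


def main():
    flare.init()  # with this commit, init() succeeds even when only METRICS_EXCHANGE is configured
    writer = SummaryWriter()
    for step in range(100):
        loss = 1.0 / (step + 1)  # stand-in for a real training step
        writer.add_scalar("train_loss", loss, global_step=step)


if __name__ == "__main__":
    main()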
nvflare/app_common/widgets/metric_relay.py (3 additions, 1 deletion)
@@ -33,6 +33,7 @@ def __init__(
        pipe_id: str,
        read_interval=0.1,
        heartbeat_interval=5.0,
+        heartbeat_timeout=60.0,
        pipe_channel_name=PipeChannelName.METRIC,
        event_type: str = ANALYTIC_EVENT_TYPE,
        fed_event: bool = True,
@@ -41,6 +42,7 @@
        self.pipe_id = pipe_id
        self._read_interval = read_interval
        self._heartbeat_interval = heartbeat_interval
+        self._heartbeat_timeout = heartbeat_timeout
        self.pipe_channel_name = pipe_channel_name
        self.pipe = None
        self.pipe_handler = None
@@ -62,7 +64,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
                pipe=self.pipe,
                read_interval=self._read_interval,
                heartbeat_interval=self._heartbeat_interval,
-                heartbeat_timeout=0,
+                heartbeat_timeout=self._heartbeat_timeout,
            )
            self.pipe_handler.set_status_cb(self._pipe_status_cb)
            self.pipe_handler.set_message_cb(self._pipe_msg_cb)
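The new heartbeat_timeout argument above is forwarded to the PipeHandler instead of the previously hard-coded 0. A brief sketch of how the widget might be instantiated with the new knob, using the signature shown in this diff; the pipe_id value is hypothetical and must match a pipe component in the job configuration:

from nvflare.app_common.widgets.metric_relay import MetricRelay

relay = MetricRelay(
    pipe_id="metrics_pipe",   # must reference a configured pipe component
    read_interval=0.1,
    heartbeat_interval=5.0,
    heartbeat_timeout=60.0,   # new: give up if no peer heartbeat arrives within this window
)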
nvflare/client/api.py (9 additions, 5 deletions)
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+import logging
import os
from enum import Enum
from typing import Any, Dict, Optional
@@ -45,12 +47,14 @@ def init(rank: Optional[str] = None):
    api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value)
    api_type = ClientAPIType(api_type_name)
    global client_api
-    if api_type == ClientAPIType.IN_PROCESS_API:
-        client_api = data_bus.get_data(CLIENT_API_KEY)
+    if client_api is None:
+        if api_type == ClientAPIType.IN_PROCESS_API:
+            client_api = data_bus.get_data(CLIENT_API_KEY)
+        else:
+            client_api = ExProcessClientAPI()
+        client_api.init(rank=rank)
    else:
-        client_api = ExProcessClientAPI()
-
-    client_api.init(rank=rank)
+        logging.warning("Warning: called init() more than once. The subsequence calls are ignored")


def receive(timeout: Optional[float] = None) -> Optional[FLModel]:
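One behavioral consequence of the reworked init() is that repeated calls no longer re-create or re-initialize the client API; they only log a warning. A short sketch:

import nvflare.client as flare

flare.init()  # creates the in-process or ex-process client API and initializes it once
flare.init()  # ignored; only logs "called init() more than once ..."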
nvflare/client/config.py (5 additions, 4 deletions)
@@ -44,6 +44,7 @@ class ConfigKey:
    TASK_NAME = "TASK_NAME"
    TASK_EXCHANGE = "TASK_EXCHANGE"
    METRICS_EXCHANGE = "METRICS_EXCHANGE"
+    HEARTBEAT_TIMEOUT = "HEARTBEAT_TIMEOUT"


class ClientConfig:
@@ -133,19 +134,19 @@ def get_pipe_class(self, section: str) -> str:
        return self.config[section][ConfigKey.PIPE][ConfigKey.CLASS_NAME]

    def get_exchange_format(self) -> str:
-        return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EXCHANGE_FORMAT]
+        return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.EXCHANGE_FORMAT, "")

    def get_transfer_type(self) -> str:
        return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRANSFER_TYPE, "FULL")

    def get_train_task(self):
-        return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.TRAIN_TASK_NAME]
+        return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRAIN_TASK_NAME, "")

    def get_eval_task(self):
-        return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EVAL_TASK_NAME]
+        return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.EVAL_TASK_NAME, "")

    def get_submit_model_task(self):
-        return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.SUBMIT_MODEL_TASK_NAME]
+        return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.SUBMIT_MODEL_TASK_NAME, "")

    def to_json(self, config_file: str):
        with open(config_file, "w") as f:
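With the .get() fallbacks, a configuration that contains only a METRICS_EXCHANGE section no longer raises KeyError from these accessors. A small sketch; constructing ClientConfig directly from a dict is an assumption based on how the getters read self.config:

from nvflare.client.config import ClientConfig, ConfigKey

# Metrics-only configuration: no TASK_EXCHANGE section at all.
config = ClientConfig({ConfigKey.METRICS_EXCHANGE: {}})

print(config.get_exchange_format())  # "" instead of a KeyError
print(config.get_transfer_type())    # "FULL" (default)
print(config.get_train_task())       # ""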
nvflare/client/ex_process/api.py (13 additions, 7 deletions)
@@ -28,14 +28,15 @@
from nvflare.client.model_registry import ModelRegistry
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.import_utils import optional_import
+from nvflare.fuel.utils.obj_utils import get_logger
from nvflare.fuel.utils.pipe.pipe import Pipe


def _create_client_config(config: str) -> ClientConfig:
    if isinstance(config, str):
        client_config = from_file(config_file=config)
    else:
-        raise ValueError("config should be a string but got: {type(config)}")
+        raise ValueError(f"config should be a string but got: {type(config)}")
    return client_config


@@ -62,6 +63,7 @@ def _register_tensor_decomposer():
class ExProcessClientAPI(APISpec):
    def __init__(self):
        self.process_model_registry = None
+        self.logger = get_logger(self)

    def get_model_registry(self) -> ModelRegistry:
        """Gets the ModelRegistry."""
@@ -81,20 +83,23 @@ def init(self, rank: Optional[str] = None):
            rank = os.environ.get("RANK", "0")

        if self.process_model_registry:
-            print("Warning: called init() more than once. The subsequence calls are ignored")
+            self.logger.warning("Warning: called init() more than once. The subsequence calls are ignored")
            return

-        client_config = _create_client_config(config=f"config/{CLIENT_API_CONFIG}")
+        config_file = f"config/{CLIENT_API_CONFIG}"
+        client_config = _create_client_config(config=config_file)

        flare_agent = None
        try:
            if rank == "0":
                if client_config.get_exchange_format() == ExchangeFormat.PYTORCH:
                    _register_tensor_decomposer()

-                pipe, task_channel_name = _create_pipe_using_config(
-                    client_config=client_config, section=ConfigKey.TASK_EXCHANGE
-                )
+                pipe, task_channel_name = None, ""
+                if ConfigKey.TASK_EXCHANGE in client_config.config:
+                    pipe, task_channel_name = _create_pipe_using_config(
+                        client_config=client_config, section=ConfigKey.TASK_EXCHANGE
+                    )
                metric_pipe, metric_channel_name = None, ""
                if ConfigKey.METRICS_EXCHANGE in client_config.config:
                    metric_pipe, metric_channel_name = _create_pipe_using_config(
@@ -106,12 +111,13 @@
                    task_channel_name=task_channel_name,
                    metric_pipe=metric_pipe,
                    metric_channel_name=metric_channel_name,
+                    heartbeat_timeout=client_config.config.get(ConfigKey.HEARTBEAT_TIMEOUT, 60),
                )
                flare_agent.start()

            self.process_model_registry = ModelRegistry(client_config, rank, flare_agent)
        except Exception as e:
-            print(f"flare.init failed: {e}")
+            self.logger.error(f"flare.init failed: {e}")
            raise e

    def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]:
nvflare/client/flare_agent.py (36 additions, 18 deletions)
@@ -64,14 +64,14 @@ def __init__(self, task_id, task_name: str, msg_id):
class FlareAgent:
    def __init__(
        self,
-        pipe: Pipe,
+        pipe: Optional[Pipe] = None,
        read_interval=0.1,
        heartbeat_interval=5.0,
-        heartbeat_timeout=30.0,
+        heartbeat_timeout=60.0,
        resend_interval=2.0,
        max_resends=None,
-        submit_result_timeout=30.0,
-        metric_pipe=None,
+        submit_result_timeout=60.0,
+        metric_pipe: Optional[Pipe] = None,
        task_channel_name: str = PipeChannelName.TASK,
        metric_channel_name: str = PipeChannelName.METRIC,
        close_pipe: bool = True,
@@ -103,21 +103,27 @@ def __init__(
                Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True.
            decomposer_module (str): the module name which contains the external decomposers.
        """
+        if pipe is None and metric_pipe is None:
+            raise RuntimeError(
+                "Please configure at least one pipe. Both the task pipe and the metric pipe are set to None."
+            )
        flare_decomposers.register()
        common_decomposers.register()
        if decomposer_module:
            register_ext_decomposers(decomposer_module)

        self.logger = logging.getLogger(self.__class__.__name__)
        self.pipe = pipe
-        self.pipe_handler = PipeHandler(
-            pipe=self.pipe,
-            read_interval=read_interval,
-            heartbeat_interval=heartbeat_interval,
-            heartbeat_timeout=heartbeat_timeout,
-            resend_interval=resend_interval,
-            max_resends=max_resends,
-        )
+        self.pipe_handler = None
+        if self.pipe:
+            self.pipe_handler = PipeHandler(
+                pipe=self.pipe,
+                read_interval=read_interval,
+                heartbeat_interval=heartbeat_interval,
+                heartbeat_timeout=heartbeat_timeout,
+                resend_interval=resend_interval,
+                max_resends=max_resends,
+            )
        self.submit_result_timeout = submit_result_timeout
        self.task_channel_name = task_channel_name
        self.metric_channel_name = metric_channel_name
@@ -148,14 +154,17 @@ def start(self):
        Returns: None
        """
-        self.pipe.open(self.task_channel_name)
-        self.pipe_handler.set_status_cb(self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name)
-        self.pipe_handler.start()
+        if self.pipe:
+            self.pipe.open(self.task_channel_name)
+            self.pipe_handler.set_status_cb(
+                self._status_cb, pipe_handler=self.pipe_handler, channel=self.task_channel_name
+            )
+            self.pipe_handler.start()

        if self.metric_pipe:
            self.metric_pipe.open(self.metric_channel_name)
            self.metric_pipe_handler.set_status_cb(
-                self._status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name
+                self._metrics_status_cb, pipe_handler=self.metric_pipe_handler, channel=self.metric_channel_name
            )
            self.metric_pipe_handler.start()
@@ -164,6 +173,11 @@ def _status_cb(self, msg: Message, pipe_handler: PipeHandler, channel):
        self.asked_to_stop = True
        pipe_handler.stop(self._close_pipe)

+    def _metrics_status_cb(self, msg: Message, pipe_handler: PipeHandler, channel):
+        self.logger.info(f"{channel} pipe status changed to {msg.topic}: {msg.data}")
+        self.asked_to_stop = True
+        pipe_handler.stop(self._close_metric_pipe)
+
    def stop(self):
        """Stop the agent.
@@ -172,9 +186,9 @@ def stop(self):
        Returns: None
        """
        self.logger.info("Calling flare agent stop")
-        self.asked_to_stop = True
-        self.pipe_handler.stop(self._close_pipe)
+        if self.pipe_handler:
+            self.pipe_handler.stop(self._close_pipe)
        if self.metric_pipe_handler:
            self.metric_pipe_handler.stop(self._close_metric_pipe)
@@ -226,6 +240,8 @@ def get_task(self, timeout: Optional[float] = None) -> Optional[Task]:
            has been submitted.
        """
+        if not self.pipe_handler:
+            raise RuntimeError("task pipe is not available")
        start_time = time.time()
        while True:
            if self.asked_to_stop:
@@ -278,6 +294,8 @@ def submit_result(self, result, rc=RC.OK) -> bool:
            made a single time regardless whether the submission is successful.
        """
+        if not self.pipe_handler:
+            raise RuntimeError("task pipe is not available")
        with self.task_lock:
            current_task = self.current_task
            if not current_task:
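Put together, FlareAgent can now be built around a metric pipe alone, and the task-oriented calls then fail fast. A rough sketch, with the pipe object left abstract because its construction (for example a CellPipe) is deployment-specific:

# metric_pipe is assumed to be an already-constructed Pipe wired to the server side.
from nvflare.client.flare_agent import FlareAgent

agent = FlareAgent(
    pipe=None,                # no task exchange configured
    metric_pipe=metric_pipe,  # metrics streaming only
)
agent.start()
# ... write metrics through the metric pipe / tracking helpers ...
agent.stop()

# agent.get_task() or agent.submit_result(...) would now raise
# RuntimeError("task pipe is not available")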
