Skip to content

Commit

Permalink
Add metrics exchange
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Dec 2, 2023
1 parent 2663882 commit a979853
Show file tree
Hide file tree
Showing 18 changed files with 508 additions and 180 deletions.
19 changes: 19 additions & 0 deletions examples/hello-world/ml-to-fl/np/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,22 @@ Then we can run it using the NVFlare Simulator:
```bash
nvflare simulator -n 2 -t 2 ./jobs/np_loop_cell_pipe -w np_loop_cell_pipe_workspace
```

## Launch once for the whole job and with metrics streaming

Sometimes we want to stream the training progress while the job is running.
To do so, the training script logs metrics through a metrics exchanger; see [./code/train_metrics.py](./code/train_metrics.py).

Then we can create the job:

```bash
nvflare job create -force -j ./jobs/np_metrics -w sag_np_cell_pipe -sd ./code/ \
-f config_fed_client.conf app_script=train_metrics.py params_transfer_type=DIFF launch_once=true \
-f config_fed_server.conf expected_data_kind=WEIGHT_DIFF
```

Then we can run it using the NVFlare Simulator:

```bash
nvflare simulator -n 2 -t 2 ./jobs/np_metrics -w np_metrics_workspace
```
85 changes: 85 additions & 0 deletions examples/hello-world/ml-to-fl/np/code/train_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time

import nvflare.client as flare
from nvflare.app_common.metrics_exchange.metrics_exchanger import IPCMetricsExchanger


def train(input_arr, current_round, epochs=3):
    """Mock a local training run and log one scalar metric per batch.

    Args:
        input_arr: starting weights; a deep copy is updated, not the argument itself.
        current_round: round number used to offset the logged global step.
        epochs: number of mock epochs to run.

    Returns:
        The deep copy of ``input_arr`` after 1 has been added once per epoch.
    """
    exchanger = IPCMetricsExchanger()
    trained_arr = copy.deepcopy(input_arr)
    total_samples = 2000
    samples_per_batch = 16
    batches_per_epoch = total_samples // samples_per_batch
    for epoch in range(epochs):
        for batch in range(batches_per_epoch):
            # the same running counter serves as both the mock loss value and the step
            step = current_round * batches_per_epoch * epochs + epoch * batches_per_epoch + batch
            exchanger.log(
                key="loss_for_each_batch",
                value=step,
                data_type=flare.AnalyticsDataType.SCALAR,
                global_step=step,
            )
        # mock training: every epoch adds 1 to the weights
        trained_arr += 1
        # assume each epoch takes 1 second
        time.sleep(1.0)
    return trained_arr


def evaluate(input_arr):
    """Return a mock evaluation metric; ``input_arr`` is not inspected."""
    mock_metric = 100
    return mock_metric


def main():
    """Run the client loop: receive weights, mock-train, mock-evaluate, send the result back."""
    # initializes NVFlare interface
    flare.init()

    # get system information
    info = flare.system_info()
    print(f"system info is: {info}")

    while flare.is_running():
        global_model = flare.receive()
        print(f"received weights is: {global_model.params}")

        received_arr = global_model.params["numpy_key"]

        # training
        trained_arr = train(received_arr, current_round=global_model.current_round, epochs=3)

        # evaluation (runs on the received weights, not on the trained ones)
        accuracy = evaluate(received_arr)

        info = flare.system_info()
        print(f"system info is: {info}")
        print(f"finish round: {global_model.current_round}")

        # send back the model
        print(f"send back: {trained_arr}")
        result_model = flare.FLModel(
            params={"numpy_key": trained_arr},
            params_type="FULL",
            metrics={"accuracy": accuracy},
            current_round=global_model.current_round,
        )
        flare.send(result_model)


# Run the client loop only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions job_templates/sag_np/config_fed_server.conf
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
# make sure this "key_metric" matches what the server side receives
args.key_metric = "accuracy"
},
{
id = "tb_analytics_receiver"
path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
args.events = ["fed.analytix_log_stats"]
}
]

Expand Down
27 changes: 27 additions & 0 deletions job_templates/sag_np_cell_pipe/config_fed_client.conf
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,33 @@
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
},
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metrics_retriever"
path = "nvflare.app_common.metrics_exchange.metrics_retriever.MetricsRetriever"
args {
pipe_id = "metrics_pipe"
event_type = "analytix_log_stats"
}
},
{
id = "event_to_fed"
path = "nvflare.app_common.widgets.convert_to_fed_event.ConvertToFedEvent"
args {
events_to_convert = ["analytix_log_stats"]
fed_event_prefix = "fed."
}
}
]
Expand Down
5 changes: 5 additions & 0 deletions job_templates/sag_np_cell_pipe/config_fed_server.conf
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
# make sure this "key_metric" matches what the server side receives
args.key_metric = "accuracy"
},
{
id = "tb_analytics_receiver"
path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
args.events = ["fed.analytix_log_stats"]
}
]

Expand Down
13 changes: 13 additions & 0 deletions job_templates/sag_pt/config_fed_client.conf
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
},
{
id = "analytic_sender"
path = "nvflare.app_common.widgets.streaming.AnalyticsSender"
args.event_type = "analytix_log_stats"
},
{
id = "event_to_fed"
path = "nvflare.app_common.widgets.convert_to_fed_event.ConvertToFedEvent"
args {
events_to_convert = ["analytix_log_stats"]
fed_event_prefix = "fed."
}
}
{
id = "pipe"
Expand Down
44 changes: 44 additions & 0 deletions nvflare/app_common/metrics_exchange/client_api_metric_receiver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.data_exchange.piper import Piper
from nvflare.client.config import ClientConfig, ConfigKey
from nvflare.client.constants import CONFIG_METRICS_EXCHANGE
from .metric_receiver import MetricReceiver


class ClientAPIMetricReceiver(MetricReceiver):
    """MetricReceiver that also writes the metrics-exchange config file at START_RUN."""

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Delegate every event to the base class and write the config on START_RUN.

        Args:
            event_type: the event being fired.
            fl_ctx: the FLContext of the event.
        """
        # Always delegate: the base class stops its pipe handler on END_RUN, and
        # that branch is never reached if only START_RUN is forwarded.
        super().handle_event(event_type, fl_ctx)

        # The base class leaves self.pipe as None when the configured component is
        # not a Pipe (it panics and returns); there is nothing to describe then.
        if event_type == EventType.START_RUN and self.pipe is not None:
            self.prepare_external_config(fl_ctx)

    def prepare_external_config(self, fl_ctx: FLContext):
        """Write the pipe settings to CONFIG_METRICS_EXCHANGE in the job's app config folder.

        Args:
            fl_ctx: the FLContext used to resolve the workspace, job id and site name.
        """
        workspace = fl_ctx.get_engine().get_workspace()
        app_dir = workspace.get_app_dir(fl_ctx.get_job_id())
        config_file = os.path.join(app_dir, workspace.config_folder, CONFIG_METRICS_EXCHANGE)

        # prepare config exchange for data exchanger
        client_config = ClientConfig()
        config_dict = client_config.config
        config_dict[ConfigKey.PIPE_CHANNEL_NAME] = self.pipe_channel_name
        config_dict[ConfigKey.PIPE_CLASS] = Piper.get_external_pipe_class(self.pipe, fl_ctx)
        config_dict[ConfigKey.PIPE_ARGS] = Piper.get_external_pipe_args(self.pipe, fl_ctx)
        config_dict[ConfigKey.SITE_NAME] = fl_ctx.get_identity_name()
        config_dict[ConfigKey.JOB_ID] = fl_ctx.get_job_id()
        client_config.to_json(config_file)
88 changes: 88 additions & 0 deletions nvflare/app_common/metrics_exchange/memory_metrics_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from queue import Queue

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.metrics_exchange.metrics_exchanger import MemoryMetricsExchanger
from nvflare.app_common.tracking.tracker_types import LogWriterName
from nvflare.app_common.widgets.streaming import ANALYTIC_EVENT_TYPE
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.memory_pipe import MemoryPipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler

from .metric_receiver import MetricReceiver


class MemoryMetricReceiver(MetricReceiver):
    """MetricReceiver whose pipe is an in-process MemoryPipe instead of an engine component.

    The PASSIVE end of the pipe is read by this receiver; the ACTIVE end is wrapped
    in a MemoryMetricsExchanger that is inserted into the engine components under
    ``metrics_exchanger_id``.
    """

    def __init__(
        self,
        metrics_exchanger_id: str,
        event_type=ANALYTIC_EVENT_TYPE,
        writer_name=LogWriterName.TORCH_TB,
        topic: str = "metrics",
        get_poll_interval: float = 0.5,
        read_interval: float = 0.1,
        heartbeat_interval: float = 5.0,
        heartbeat_timeout: float = 30.0,
    ):
        """Metrics receiver with memory pipe.

        Args:
            metrics_exchanger_id (str): component id under which the created
                MemoryMetricsExchanger is inserted into the engine components.
            event_type (str): event type to fire (defaults to "analytix_log_stats").
            writer_name: the log writer for syntax information (defaults to LogWriterName.TORCH_TB)
            topic (str): kept for backward compatibility; stored but not used by this class.
            get_poll_interval (float): kept for backward compatibility; stored but not used by this class.
            read_interval (float): passed to the PipeHandlers.
            heartbeat_interval (float): passed to the PipeHandlers.
            heartbeat_timeout (float): passed to the PipeHandlers.
        """
        # MetricReceiver.__init__ only accepts pipe_id, the three intervals and the
        # channel name. pipe_id is required but never looked up here, because this
        # class builds its own MemoryPipe in _init_pipe.
        super().__init__(
            pipe_id="",
            read_interval=read_interval,
            heartbeat_interval=heartbeat_interval,
            heartbeat_timeout=heartbeat_timeout,
        )
        self.metrics_exchanger_id = metrics_exchanger_id

        # Accepted for signature compatibility; the base class has no use for them.
        # NOTE(review): event_type / writer_name are not forwarded when metrics are
        # sent, since the base _pipe_msg_cb does not take them - confirm intent.
        self._event_type = event_type
        self._writer_name = writer_name
        self._topic = topic
        self._get_poll_interval = get_poll_interval

        # the two queues shared by both ends of the memory pipe
        self.x_queue = Queue()
        self.y_queue = Queue()

    def _init_pipe(self, fl_ctx: FLContext) -> None:
        """Create the PASSIVE end of the memory pipe as this receiver's pipe."""
        self.pipe = MemoryPipe(x_queue=self.x_queue, y_queue=self.y_queue, mode=Mode.PASSIVE)

    def _create_metrics_exchanger(self):
        """Open the ACTIVE end of the memory pipe and wrap it in a MemoryMetricsExchanger."""
        pipe = MemoryPipe(x_queue=self.x_queue, y_queue=self.y_queue, mode=Mode.ACTIVE)
        pipe.open(name=self.pipe_channel_name)
        # init pipe handler
        pipe_handler = PipeHandler(
            pipe,
            read_interval=self.read_interval,
            heartbeat_interval=self.heartbeat_interval,
            heartbeat_timeout=self.heartbeat_timeout,
        )
        pipe_handler.start()
        metrics_exchanger = MemoryMetricsExchanger(pipe_handler=pipe_handler)
        return metrics_exchanger

    def _start_receiving(self, fl_ctx: FLContext) -> None:
        """Open the PASSIVE pipe and start a PipeHandler wired to the base-class callbacks."""
        self._fl_ctx = fl_ctx
        self._init_pipe(fl_ctx)
        self.pipe_handler = PipeHandler(
            pipe=self.pipe,
            read_interval=self.read_interval,
            heartbeat_interval=self.heartbeat_interval,
            heartbeat_timeout=self.heartbeat_timeout,
        )
        self.pipe_handler.set_status_cb(self._pipe_status_cb)
        self.pipe_handler.set_message_cb(self._pipe_msg_cb)
        self.pipe.open(self.pipe_channel_name)
        self.pipe_handler.start()

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Set up both pipe ends at run start and defer all other events to the base class.

        Args:
            event_type: the event being fired.
            fl_ctx: the FLContext of the event.
        """
        if event_type == EventType.START_RUN:
            # The base START_RUN branch fetches the pipe from the engine by pipe_id
            # and panics if it is not a Pipe; the memory pipe is created locally.
            self._start_receiving(fl_ctx)
            return

        # END_RUN (stopping the pipe handler) is handled by the base class.
        super().handle_event(event_type, fl_ctx)
        if event_type == EventType.ABOUT_TO_START_RUN:
            engine = fl_ctx.get_engine()
            # inserts MetricsExchanger into engine components
            metrics_exchanger = self._create_metrics_exchanger()
            all_components = engine.get_all_components()
            all_components[self.metrics_exchanger_id] = metrics_exchanger

    def prepare_external_config(self, fl_ctx: FLContext):
        """No-op: this receiver writes no external config."""
        pass
78 changes: 78 additions & 0 deletions nvflare/app_common/metrics_exchange/metric_receiver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.dxo import DXO
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.fuel.utils.constants import PipeChannelName
from nvflare.fuel.utils.pipe.pipe import Message, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler
from nvflare.apis.fl_component import FLComponent
from nvflare.app_common.tracking.analytic_utils import send_analytic_dxo


class MetricReceiver(FLComponent):
    """Reads metric records from a pipe and forwards them with ``send_analytic_dxo``."""

    def __init__(
        self,
        pipe_id: str,
        read_interval=0.1,
        heartbeat_interval=5.0,
        heartbeat_timeout=30.0,
        pipe_channel_name=PipeChannelName.METRIC,
    ):
        """Store the pipe settings; the pipe itself is resolved at START_RUN.

        Args:
            pipe_id: id of the engine component to read from; it must be a Pipe.
            read_interval: passed to the PipeHandler.
            heartbeat_interval: passed to the PipeHandler.
            heartbeat_timeout: passed to the PipeHandler.
            pipe_channel_name: name used when opening the pipe.
        """
        super().__init__()
        self.pipe_id = pipe_id
        self.read_interval = read_interval
        self.heartbeat_interval = heartbeat_interval
        self.heartbeat_timeout = heartbeat_timeout
        self.pipe_channel_name = pipe_channel_name
        self.pipe = None
        self.pipe_handler = None
        self._fl_ctx = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Start the pipe handler on START_RUN and stop it on END_RUN.

        Args:
            event_type: the event being fired.
            fl_ctx: the FLContext of the event.
        """
        if event_type == EventType.START_RUN:
            engine = fl_ctx.get_engine()
            pipe = engine.get_component(self.pipe_id)
            if not isinstance(pipe, Pipe):
                self.log_error(fl_ctx, f"component {self.pipe_id} must be Pipe but got {type(pipe)}")
                self.system_panic(f"bad component {self.pipe_id}", fl_ctx)
                return
            # kept so the message callback can pass a context to send_analytic_dxo
            self._fl_ctx = fl_ctx
            self.pipe = pipe
            self.pipe_handler = PipeHandler(
                pipe=self.pipe,
                read_interval=self.read_interval,
                heartbeat_interval=self.heartbeat_interval,
                heartbeat_timeout=self.heartbeat_timeout,
            )
            self.pipe_handler.set_status_cb(self._pipe_status_cb)
            self.pipe_handler.set_message_cb(self._pipe_msg_cb)
            self.pipe.open(self.pipe_channel_name)
            self.pipe_handler.start()
        elif event_type == EventType.END_RUN:
            self.log_info(fl_ctx, "Stopping pipe handler")
            if self.pipe_handler:
                self.pipe_handler.notify_end("end_of_job")
                self.pipe_handler.stop()

    def _pipe_status_cb(self, msg: Message):
        """Log the new pipe status and stop the pipe handler."""
        self.logger.info(f"{self.pipe_channel_name} pipe status changed to {msg.topic}")
        self.pipe_handler.stop()

    def _pipe_msg_cb(self, msg: Message):
        """Forward a received metric DXO; log and drop any other payload."""
        if not isinstance(msg.data, DXO):
            self.logger.error(f"bad metric data: expect DXO but got {type(msg.data)}")
            # Drop the message: without this return the code below would read
            # msg.data.data and forward a payload that is not a DXO.
            return
        self.logger.info(f"received metric record: {msg.topic}: {msg.data.data}")
        send_analytic_dxo(self, msg.data, self._fl_ctx)
Loading

0 comments on commit a979853

Please sign in to comment.