Skip to content

Commit

Permalink
Add he support to pt params converter (#2238)
Browse files Browse the repository at this point in the history
* Add he support to pt params converter

* change to path, add condition

---------

Co-authored-by: Chester Chen <[email protected]>
  • Loading branch information
SYangster and chesterxgchen authored Dec 21, 2023
1 parent 2ad8da2 commit f4bd1b4
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 1 deletion.
136 changes: 136 additions & 0 deletions job_templates/sag_pt_he/config_fed_client.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
{
# version of the configuration
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "train.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = ""

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = ["train"]

# This particular executor
executor {

# This is an executor for the Client API. The underlying data exchange uses a Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

args {
# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "pytorch"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "DIFF"

# if train_with_evaluation is true, the executor will expect
# the custom code needs to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = true
}
}
}
],

task_data_filters = [
{
tasks = ["train"]
filters = [
{
path = "nvflare.app_opt.he.model_decryptor.HEModelDecryptor"
args {
}
}
]
}
]
task_result_filters = [
{
tasks = ["train"]
filters = [
{
path = "nvflare.app_opt.he.model_encryptor.HEModelEncryptor"
args {
weigh_by_local_iter = true
}
}
]
},
]

components = [
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
}
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
# how fast should it read from the peer
read_interval = 0.1
}
},
{
# we use this component so the client api `flare.init()` can get required information
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = ["metric_relay"]
}
}
]
}
117 changes: 117 additions & 0 deletions job_templates/sag_pt_he/config_fed_server.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
{
# version of the configuration
format_version = 2

# task data filter: if filters are provided, the filter will filter the data flow out of server to client.
task_data_filters =[]

# task result filter: if filters are provided, the filter will filter the result flow out of client to server.
task_result_filters = []

# This assumes that there will be a "net.py" file with class name "Net".
# If your model code is not in "net.py" and class name is not "Net", please modify here
model_class_path = "net.Net"

# workflows: Array of workflows that control the Federated Learning workflow lifecycle.
# One can specify multiple workflows. The NVFLARE will run them in the order specified.
workflows = [
{
# 1st workflow
id = "scatter_and_gather"

# name = ScatterAndGather, path is the class path of the ScatterAndGather controller.
path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
args {
# argument of the ScatterAndGather class.
# min number of clients required for ScatterAndGather controller to move to the next round
# during the workflow cycle. The controller will wait until min_clients results are returned from clients
# before moving to the next step.
min_clients = 2

# number of global rounds of training.
num_rounds = 2

# starting round is 0-based
start_round = 0

# after receiving the min number of clients' results,
# how much longer to wait before moving to the next step
wait_time_after_min_received = 0

# For ScatterAndGather, the server will aggregate the weights based on the client's result.
# the aggregator component id is named here. One can use this ID to find the corresponding
# aggregator component listed below
aggregator_id = "aggregator"

# The Scatter and Gather controller uses a persistor to load and save the model.
# The persistent component can be identified by component ID specified here.
persistor_id = "persistor"

# Shareable to a communication message, i.e. shared between clients and server.
# The Shareable generator is a component responsible for converting the model to/from this communication message: Shareable.
# The component can be identified via "shareable_generator_id"
shareable_generator_id = "shareable_generator"

# train task name: client side needs to have an executor that handles this task
train_task_name = "train"

# train timeout in seconds; zero means no timeout.
train_timeout = 0
}
}
]

# List of components used in the server side workflow.
components = [
{
# This is the persistence component used in above workflow.
# PTFileModelPersistor is a Pytorch persistor which save/read the model to/from file.

id = "persistor"
path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"

# the persistor class takes the model class as an argument.
# This implies that the model is initialized on the server side.
# The initialized model will be broadcast to all the clients to start the training.
args {
model {
path = "{model_class_path}"
}
filter_id = "serialize_filter"
}
},
{
id = "shareable_generator"
path = "nvflare.app_opt.he.model_shareable_generator.HEModelShareableGenerator"
args {}
}
{
id = "aggregator"
path = "nvflare.app_opt.he.intime_accumulate_model_aggregator.HEInTimeAccumulateWeightedAggregator"
args {
weigh_by_local_iter = false
expected_data_kind = "WEIGHT_DIFF"
}
}
{
id = "serialize_filter"
path = "nvflare.app_opt.he.model_serialize_filter.HEModelSerializeFilter"
args {
}
}
{
# This component is not directly used in Workflow.
# it select the best model based on the incoming global validation metrics.
id = "model_selector"
path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
# need to make sure this "key_metric" match what server side received
args.key_metric = "accuracy"
},
{
id = "receiver"
path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
args.events = ["fed.analytix_log_stats"]
}
]

}
5 changes: 5 additions & 0 deletions job_templates/sag_pt_he/info.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
description = "scatter & gather workflow using pytorch and homomorphic encryption"
client_category = "client_api"
controller_type = "server"
}
11 changes: 11 additions & 0 deletions job_templates/sag_pt_he/info.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Job Template Information Card

## sag_pt_he
name = "sag_pt_he"
description = "Scatter and Gather Workflow using pytorch and homomorphic encryption"
class_name = "ScatterAndGather"
controller_type = "server"
executor_type = "launcher_executor"
contributor = "NVIDIA"
init_publish_date = "2023-12-20"
last_updated_date = "2023-12-20" # yyyy-mm-dd
10 changes: 10 additions & 0 deletions job_templates/sag_pt_he/meta.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
name = "sag_pt_he"
resource_spec = {}
deploy_map {
# change deploy map as needed.
app = ["@ALL"]
}
min_clients = 2
mandatory_clients = []
}
11 changes: 10 additions & 1 deletion nvflare/app_opt/pt/params_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,25 @@

from typing import Dict

import numpy as np
import torch

from nvflare.app_common.abstract.params_converter import ParamsConverter


class NumpyToPTParamsConverter(ParamsConverter):
    """Converts a dict of numpy arrays back into PyTorch tensors.

    When a "tensor_shapes" prop was recorded in the FL context (by
    PTToNumpyParamsConverter), each array is reshaped to its original tensor
    shape before conversion. This is needed for HE workflows, where encrypted
    parameters come back flattened.
    """

    def convert(self, params: Dict, fl_ctx) -> Dict:
        """Convert numpy params to torch tensors.

        Args:
            params: mapping of parameter name -> numpy array (possibly
                flattened, e.g. after HE decryption).
            fl_ctx: FL context; may carry a "tensor_shapes" prop mapping
                parameter names to their original shapes.

        Returns:
            Mapping of parameter name -> torch.Tensor, reshaped to the
            recorded shape when one exists for that key.
        """
        # NOTE: the stale unconditional return that preceded this lookup made
        # the shape-restoring logic unreachable; it has been removed.
        tensor_shapes = fl_ctx.get_prop("tensor_shapes")
        if tensor_shapes:
            return {
                k: torch.as_tensor(np.reshape(v, tensor_shapes[k])) if k in tensor_shapes else torch.as_tensor(v)
                for k, v in params.items()
            }
        return {k: torch.as_tensor(v) for k, v in params.items()}


class PTToNumpyParamsConverter(ParamsConverter):
    """Converts a dict of PyTorch tensors into numpy arrays.

    Each tensor's shape is recorded in the FL context under the
    "tensor_shapes" prop so the reverse converter can restore shapes later.
    """

    def convert(self, params: Dict, fl_ctx) -> Dict:
        """Convert torch tensor params to numpy arrays, recording shapes.

        Args:
            params: mapping of parameter name -> torch.Tensor.
            fl_ctx: FL context; receives a "tensor_shapes" prop mapping
                each parameter name to its tensor shape.

        Returns:
            Mapping of parameter name -> numpy array (moved to CPU first).
        """
        # Record original shapes before conversion, mirroring the original order.
        fl_ctx.set_prop("tensor_shapes", {name: tensor.shape for name, tensor in params.items()})
        arrays = {}
        for name, tensor in params.items():
            arrays[name] = tensor.cpu().numpy()
        return arrays

0 comments on commit f4bd1b4

Please sign in to comment.