diff --git a/examples/README.md b/examples/README.md
index b5f482849f..fb531eba34 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -22,7 +22,7 @@ The provided examples cover different aspects of [NVIDIA FLARE](https://nvidia.g
 ## 2. FL algorithms
 * [Federated Learning with CIFAR-10](./cifar10/README.md)
-  * Includes examples of using [FedAvg](https://arxiv.org/abs/1602.05629), [FedProx](https://arxiv.org/abs/1812.06127), [FedOpt](https://arxiv.org/abs/2003.00295), [homomorphic encryption](https://developer.nvidia.com/blog/federated-learning-with-homomorphic-encryption/), and streaming of TensorBoard metrics to the server during training.
+  * Includes examples of using [FedAvg](https://arxiv.org/abs/1602.05629), [FedProx](https://arxiv.org/abs/1812.06127), [FedOpt](https://arxiv.org/abs/2003.00295), [SCAFFOLD](https://arxiv.org/abs/1910.06378), [homomorphic encryption](https://developer.nvidia.com/blog/federated-learning-with-homomorphic-encryption/), and streaming of TensorBoard metrics to the server during training.
 ## 3. Medical Image Analysis
 * [Hello MONAI](./hello-monai/README.md)
diff --git a/examples/cifar10/README.md b/examples/cifar10/README.md
index 563c17cdba..514e15c8a2 100644
--- a/examples/cifar10/README.md
+++ b/examples/cifar10/README.md
@@ -123,19 +123,23 @@ for example
 cat ./workspaces/poc_workspace/server/run_2/cross_site_val/cross_site_val.json
 ```
-### 3.4: Advanced FL algorithms (FedProx and FedOpt)
+### 3.4: Advanced FL algorithms (FedProx, FedOpt, and SCAFFOLD)
 
 Next, let's try some different FL algorithms on a more heterogeneous split:
 
-FedProx (https://arxiv.org/abs/1812.06127), adds a regularizer to the CIFAR10Trainer loss (`fedproxloss_mu`)`:
+[FedProx](https://arxiv.org/abs/1812.06127) adds a regularizer to the loss used in `CIFAR10Learner` (`fedproxloss_mu`):
 ```
 ./run_poc.sh 8 cifar10_fedprox 6 0.1
 ```
-FedOpt (https://arxiv.org/abs/2003.00295), uses a new ShareableGenerator to update the global model on the server using a PyTorch optimizer.
+[FedOpt](https://arxiv.org/abs/2003.00295) uses a new ShareableGenerator to update the global model on the server using a PyTorch optimizer.
 Here SGD with momentum and cosine learning rate decay:
 ```
 ./run_poc.sh 8 cifar10_fedopt 7 0.1
 ```
+[SCAFFOLD](https://arxiv.org/abs/1910.06378) uses a slightly modified version of the CIFAR-10 Learner implementation, namely the `CIFAR10ScaffoldLearner`, which adds a correction term during local training, following the [implementation](https://github.com/Xtra-Computing/NIID-Bench) described in [Li et al.](https://arxiv.org/abs/2102.02079).
+```
+./run_poc.sh 8 cifar10_scaffold 8 0.1
+```
 
 ### 3.5 Secure aggregation using homomorphic encryption
 
Next we run FedAvg using homomorphic encryption (HE) for secure aggregation on t
 
 FedAvg with HE:
 ```
-./run_secure.sh 8 cifar10_fedavg_he 8 1.0
+./run_secure.sh 8 cifar10_fedavg_he 9 1.0
 ```
 
 > **_NOTE:_** Currently, FedOpt is not supported with HE as it would involve running the optimizer on encrypted values.
 
@@ -173,8 +177,8 @@ that HE does not impact the performance accuracy of FedAvg significantly while a
 
 | Config | Alpha | Val score |
 | ----------- | ----------- | ----------- |
 | cifar10_central | 1.0 | 0.8798 |
-| cifar10_fedavg | 1.0 | 0.8873 |
-| cifar10_fedavg_he | 1.0 | 0.8864 |
+| cifar10_fedavg | 1.0 | 0.8854 |
+| cifar10_fedavg_he | 1.0 | 0.8897 |
 
 ![Central vs. FedAvg](./figs/central_vs_fedavg_he.png)
 
@@ -185,33 +189,33 @@ This can be observed in the resulting performance of the FedAvg algorithms.
 | Config | Alpha | Val score |
 | ----------- | ----------- | ----------- |
-| cifar10_fedavg | 1.0 | 0.8873 |
-| cifar10_fedavg | 0.5 | 0.8726 |
-| cifar10_fedavg | 0.3 | 0.8315 |
-| cifar10_fedavg | 0.1 | 0.7726 |
+| cifar10_fedavg | 1.0 | 0.8854 |
+| cifar10_fedavg | 0.5 | 0.8633 |
+| cifar10_fedavg | 0.3 | 0.8350 |
+| cifar10_fedavg | 0.1 | 0.7733 |
 
 ![Impact of client data heterogeneity](./figs/fedavg_alpha.png)
 
-### 4.3 FedProx vs. FedOpt
+### 4.3 FedAvg vs. FedProx vs. FedOpt vs. SCAFFOLD
 
-Finally, we are comparing an `alpha` setting of 0.1, causing a high client data heterogeneity and its
-impact on more advanced FL algorithms, namely FedProx and FedOpt. Both achieve a better performance compared to FedAvg
-with the same `alpha` setting but FedOpt shows a better convergence rate by utilizing SGD with momentum
-to update the global model on the server, and achieves a better performance with the same amount of training steps.
+Finally, we compare an `alpha` setting of 0.1, which causes high client data heterogeneity, and its
+impact on the more advanced FL algorithms, namely FedProx, FedOpt, and SCAFFOLD. FedOpt and SCAFFOLD achieve better performance compared to FedAvg and FedProx with the same `alpha` setting, and both show markedly better convergence rates. SCAFFOLD achieves this by adding a correction term when updating the client models, while FedOpt utilizes SGD with momentum
+to update the global model on the server. Both do so within the same number of training steps as FedAvg/FedProx.
 
 | Config | Alpha | Val score |
 |------------------| ----------- | ---------- |
-| cifar10_fedavg | 0.1 | 0.7726 |
-| cifar10_fedprox | 0.1 | 0.7512 |
-| cifar10_fedopt | 0.1 | 0.7986 |
+| cifar10_fedavg | 0.1 | 0.7733 |
+| cifar10_fedprox | 0.1 | 0.7615 |
+| cifar10_fedopt | 0.1 | 0.8013 |
+| cifar10_scaffold | 0.1 | 0.8222 |
 
-![FedProx vs. FedOpt](./figs/fedopt_fedprox.png)
+![FedAvg vs. FedProx vs. FedOpt vs. SCAFFOLD](./figs/fedopt_fedprox_scaffold.png)
 
 ## 5. Streaming TensorBoard metrics to the server
 
 In a real-world scenario, the researcher won't have access to the TensorBoard events of the individual clients. In order to visualize the training performance in a central place, `AnalyticsSender`, `ConvertToFedEvent` on the client, and `TBAnalyticsReceiver` on the server can be used. For an example using FedAvg and metric streaming during training, run:
 ```
-./run_poc.sh 8 cifar10_fedavg_stream_tb 9 1.0
+./run_poc.sh 8 cifar10_fedavg_stream_tb 10 1.0
 ```
 Using this configuration, a `tb_events` folder will be created under the `run_*` folder of the server that includes all the TensorBoard event values of the different clients.
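
To make the FedProx change above concrete: `fedproxloss_mu` weights a proximal term that is added to the local training loss. Below is a minimal sketch of that term for illustration only; the example itself uses `PTFedProxLoss` from `nvflare.app_common.pt.pt_fedproxloss`, and the helper name `fedprox_penalty` here is hypothetical:

```python
import torch


def fedprox_penalty(model: torch.nn.Module, model_global: torch.nn.Module, mu: float) -> torch.Tensor:
    """Proximal term (mu / 2) * ||w - w_global||^2 that penalizes drift from the global model."""
    penalty = 0.0
    for w, w_global in zip(model.parameters(), model_global.parameters()):
        # model_global is frozen (requires_grad=False in the learner), so no gradients flow into it
        penalty = penalty + torch.sum((w - w_global) ** 2)
    return 0.5 * mu * penalty


# inside the local training loop (sketch):
# loss = criterion(outputs, labels) + fedprox_penalty(self.model, model_global, self.fedproxloss_mu)
```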
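Likewise, a note on the `alpha` values used throughout the README above: `alpha` acts as a Dirichlet concentration parameter that controls how non-IID the per-client label distributions are. A rough sketch of such a split follows; the example's actual data preparation code may differ in detail, and `dirichlet_split` is a hypothetical helper:

```python
import numpy as np


def dirichlet_split(labels: np.ndarray, n_clients: int, alpha: float, seed: int = 0):
    """Partition sample indices across clients; smaller alpha -> more heterogeneous clients."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idx_c = rng.permutation(np.where(labels == c)[0])
        # per-class proportions drawn from Dir(alpha): alpha=1.0 is mild, alpha=0.1 is highly skewed
        proportions = rng.dirichlet(alpha * np.ones(n_clients))
        split_points = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for client_id, chunk in enumerate(np.split(idx_c, split_points)):
            client_indices[client_id].extend(chunk.tolist())
    return client_indices
```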
diff --git a/examples/cifar10/configs/cifar10_central/config/config_fed_server.json b/examples/cifar10/configs/cifar10_central/config/config_fed_server.json index 5821b58154..49195ac8af 100644 --- a/examples/cifar10/configs/cifar10_central/config/config_fed_server.json +++ b/examples/cifar10/configs/cifar10_central/config/config_fed_server.json @@ -37,17 +37,14 @@ }, { "id": "model_locator", - "path": "pt.pt_model_locator.PTModelLocator", - "args": {} - }, - { - "id": "formatter", - "path": "pt.pt_formatter.PTFormatter", - "args": {} + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } }, { "id": "json_generator", - "path": "pt.validation_json_generator.ValidationJsonGenerator", + "name": "ValidationJsonGenerator", "args": {} } ], @@ -72,7 +69,6 @@ "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator", - "formatter_id": "formatter", "submit_model_timeout": 600, "validation_timeout": 6000, "cleanup_models": true diff --git a/examples/cifar10/configs/cifar10_fedavg/config/config_fed_server.json b/examples/cifar10/configs/cifar10_fedavg/config/config_fed_server.json index e02429dd8b..9d819b976f 100644 --- a/examples/cifar10/configs/cifar10_fedavg/config/config_fed_server.json +++ b/examples/cifar10/configs/cifar10_fedavg/config/config_fed_server.json @@ -37,17 +37,14 @@ }, { "id": "model_locator", - "path": "pt.pt_model_locator.PTModelLocator", - "args": {} - }, - { - "id": "formatter", - "path": "pt.pt_formatter.PTFormatter", - "args": {} + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } }, { "id": "json_generator", - "path": "pt.validation_json_generator.ValidationJsonGenerator", + "name": "ValidationJsonGenerator", "args": {} } ], @@ -72,7 +69,6 @@ "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator", - "formatter_id": "formatter", "submit_model_timeout": 600, "validation_timeout": 6000, "cleanup_models": true diff --git a/examples/cifar10/configs/cifar10_fedavg_he/config/config_fed_server.json b/examples/cifar10/configs/cifar10_fedavg_he/config/config_fed_server.json index 8627b0a682..113af6c710 100644 --- a/examples/cifar10/configs/cifar10_fedavg_he/config/config_fed_server.json +++ b/examples/cifar10/configs/cifar10_fedavg_he/config/config_fed_server.json @@ -37,17 +37,14 @@ }, { "id": "model_locator", - "path": "pt.pt_model_locator.PTModelLocator", - "args": {} - }, - { - "id": "formatter", - "path": "pt.pt_formatter.PTFormatter", - "args": {} + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } }, { "id": "json_generator", - "path": "pt.validation_json_generator.ValidationJsonGenerator", + "name": "ValidationJsonGenerator", "args": {} } ], @@ -72,7 +69,6 @@ "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator", - "formatter_id": "formatter", "submit_model_timeout": 600, "validation_timeout": 6000, "cleanup_models": true diff --git a/examples/cifar10/configs/cifar10_fedavg_stream_tb/config/config_fed_server.json b/examples/cifar10/configs/cifar10_fedavg_stream_tb/config/config_fed_server.json index a4c02f01f4..7f5f34411a 100644 --- a/examples/cifar10/configs/cifar10_fedavg_stream_tb/config/config_fed_server.json +++ b/examples/cifar10/configs/cifar10_fedavg_stream_tb/config/config_fed_server.json @@ -37,17 +37,14 @@ }, { "id": "model_locator", - "path": "pt.pt_model_locator.PTModelLocator", - "args": {} - }, - { - "id": "formatter", - "path": "pt.pt_formatter.PTFormatter", - "args": {} + "name": "PTFileModelLocator", + "args": { + 
"pt_persistor_id": "persistor" + } }, { "id": "json_generator", - "path": "pt.validation_json_generator.ValidationJsonGenerator", + "name": "ValidationJsonGenerator", "args": {} }, { @@ -77,7 +74,6 @@ "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator", - "formatter_id": "formatter", "submit_model_timeout": 600, "validation_timeout": 6000, "cleanup_models": true diff --git a/examples/cifar10/configs/cifar10_fedopt/config/config_fed_server.json b/examples/cifar10/configs/cifar10_fedopt/config/config_fed_server.json index 011fcb46fe..0ee266528b 100644 --- a/examples/cifar10/configs/cifar10_fedopt/config/config_fed_server.json +++ b/examples/cifar10/configs/cifar10_fedopt/config/config_fed_server.json @@ -56,17 +56,14 @@ }, { "id": "model_locator", - "path": "pt.pt_model_locator.PTModelLocator", - "args": {} - }, - { - "id": "formatter", - "path": "pt.pt_formatter.PTFormatter", - "args": {} + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } }, { "id": "json_generator", - "path": "pt.validation_json_generator.ValidationJsonGenerator", + "name": "ValidationJsonGenerator", "args": {} } ], @@ -91,7 +88,6 @@ "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator", - "formatter_id": "formatter", "submit_model_timeout": 600, "validation_timeout": 6000, "cleanup_models": true diff --git a/examples/cifar10/configs/cifar10_fedprox/config/config_fed_server.json b/examples/cifar10/configs/cifar10_fedprox/config/config_fed_server.json index e02429dd8b..9d819b976f 100644 --- a/examples/cifar10/configs/cifar10_fedprox/config/config_fed_server.json +++ b/examples/cifar10/configs/cifar10_fedprox/config/config_fed_server.json @@ -37,17 +37,14 @@ }, { "id": "model_locator", - "path": "pt.pt_model_locator.PTModelLocator", - "args": {} - }, - { - "id": "formatter", - "path": "pt.pt_formatter.PTFormatter", - "args": {} + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } }, { "id": "json_generator", - "path": "pt.validation_json_generator.ValidationJsonGenerator", + "name": "ValidationJsonGenerator", "args": {} } ], @@ -72,7 +69,6 @@ "name": "CrossSiteModelEval", "args": { "model_locator_id": "model_locator", - "formatter_id": "formatter", "submit_model_timeout": 600, "validation_timeout": 6000, "cleanup_models": true diff --git a/examples/cifar10/configs/cifar10_scaffold/config/config_fed_client.json b/examples/cifar10/configs/cifar10_scaffold/config/config_fed_client.json new file mode 100644 index 0000000000..836e9502b2 --- /dev/null +++ b/examples/cifar10/configs/cifar10_scaffold/config/config_fed_client.json @@ -0,0 +1,37 @@ +{ + "format_version": 2, + + "DATASET_ROOT": "/tmp/cifar10_data", + + "executors": [ + { + "tasks": [ + "train", "submit_model", "validate" + ], + "executor": { + "id": "Executor", + "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor", + "args": { + "learner_id": "cifar10-learner" + } + } + } + ], + + "task_result_filters": [ + ], + "task_data_filters": [ + ], + + "components": [ + { + "id": "cifar10-learner", + "path": "pt.learners.cifar10_scaffold_learner.CIFAR10ScaffoldLearner", + "args": { + "dataset_root": "{DATASET_ROOT}", + "aggregation_epochs": 4, + "lr": 1e-2 + } + } + ] +} diff --git a/examples/cifar10/configs/cifar10_scaffold/config/config_fed_server.json b/examples/cifar10/configs/cifar10_scaffold/config/config_fed_server.json new file mode 100644 index 0000000000..05232fab5e --- /dev/null +++ 
b/examples/cifar10/configs/cifar10_scaffold/config/config_fed_server.json @@ -0,0 +1,83 @@ +{ + "format_version": 2, + + "min_clients": 8, + "num_rounds": 50, + + "server": { + "heart_beat_timeout": 600 + }, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "persistor", + "name": "PTFileModelPersistor", + "args": { + "model": { + "path": "pt.networks.cifar10_nets.ModerateCNN", + "args": {} + } + } + }, + { + "id": "shareable_generator", + "name": "FullModelShareableGenerator", + "args": {} + }, + { + "id": "aggregator", + "name": "InTimeAccumulateWeightedAggregator", + "args": { + "expected_data_kind": { + "_model_weights_": "WEIGHT_DIFF", + "scaffold_c_diff": "WEIGHT_DIFF" + } + } + }, + { + "id": "model_selector", + "name": "IntimeModelSelectionHandler", + "args": {} + }, + { + "id": "model_locator", + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } + }, + { + "id": "json_generator", + "name": "ValidationJsonGenerator", + "args": {} + } + ], + "workflows": [ + { + "id": "scatter_gather_ctl", + "name": "ScatterAndGatherScaffold", + "args": { + "min_clients" : "{min_clients}", + "num_rounds" : "{num_rounds}", + "start_round": 0, + "wait_time_after_min_received": 10, + "aggregator_id": "aggregator", + "persistor_id": "persistor", + "shareable_generator_id": "shareable_generator", + "train_task_name": "train", + "train_timeout": 0 + } + }, + { + "id": "cross_site_model_eval", + "name": "CrossSiteModelEval", + "args": { + "model_locator_id": "model_locator", + "submit_model_timeout": 600, + "validation_timeout": 6000, + "cleanup_models": true + } + } + ] +} diff --git a/examples/cifar10/figs/central_training.png b/examples/cifar10/figs/central_training.png index bbe1d9af29..73fdbaa379 100644 Binary files a/examples/cifar10/figs/central_training.png and b/examples/cifar10/figs/central_training.png differ diff --git a/examples/cifar10/figs/central_vs_fedavg_he.png b/examples/cifar10/figs/central_vs_fedavg_he.png index fce0ccd35c..af7bfc7e21 100644 Binary files a/examples/cifar10/figs/central_vs_fedavg_he.png and b/examples/cifar10/figs/central_vs_fedavg_he.png differ diff --git a/examples/cifar10/figs/fedavg_alpha.png b/examples/cifar10/figs/fedavg_alpha.png index 5064f55d80..72fe68767e 100644 Binary files a/examples/cifar10/figs/fedavg_alpha.png and b/examples/cifar10/figs/fedavg_alpha.png differ diff --git a/examples/cifar10/figs/fedopt_fedprox.png b/examples/cifar10/figs/fedopt_fedprox.png deleted file mode 100644 index fbc754373f..0000000000 Binary files a/examples/cifar10/figs/fedopt_fedprox.png and /dev/null differ diff --git a/examples/cifar10/figs/fedopt_fedprox_scaffold.png b/examples/cifar10/figs/fedopt_fedprox_scaffold.png new file mode 100644 index 0000000000..ff23c567fb Binary files /dev/null and b/examples/cifar10/figs/fedopt_fedprox_scaffold.png differ diff --git a/examples/cifar10/figs/plot_tensorboard_events.py b/examples/cifar10/figs/plot_tensorboard_events.py index 39fd715a4c..ee5b331e76 100644 --- a/examples/cifar10/figs/plot_tensorboard_events.py +++ b/examples/cifar10/figs/plot_tensorboard_events.py @@ -12,34 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import glob import json import os -import glob -import pandas as pd -import tensorflow as tf import matplotlib.pyplot as plt +import pandas as pd import seaborn as sns - +import tensorflow as tf client_results_root = "./workspaces/secure_workspace/site-1" server_results_root = "./workspaces/secure_workspace/localhost" # 4.1 Central vs. FedAvg -experiments = {"cifar10_central": {"run": "run_1", "tag": "val_acc_local_model"}, - "cifar10_fedavg": {"run": "run_2", "tag": "val_acc_global_model"}, - "cifar10_fedavg_he": {"run": "run_8", "tag": "val_acc_global_model"}} +experiments = { + "cifar10_central": {"run": "run_1", "tag": "val_acc_local_model"}, + "cifar10_fedavg": {"run": "run_2", "tag": "val_acc_global_model"}, + "cifar10_fedavg_he": {"run": "run_9", "tag": "val_acc_global_model"}, +} # # 4.2 Impact of client data heterogeneity # experiments = {"cifar10_fedavg (alpha=1.0)": {"run": "run_2", "tag": "val_acc_global_model"}, -# "cifar10_fedavg (alpha=0.5)": {"run": "run_3", "tag": "val_acc_global_model"}, -# "cifar10_fedavg (alpha=0.3)": {"run": "run_4", "tag": "val_acc_global_model"}, -# "cifar10_fedavg (alpha=0.1)": {"run": "run_5", "tag": "val_acc_global_model"}} +# "cifar10_fedavg (alpha=0.5)": {"run": "run_3", "tag": "val_acc_global_model"}, +# "cifar10_fedavg (alpha=0.3)": {"run": "run_4", "tag": "val_acc_global_model"}, +# "cifar10_fedavg (alpha=0.1)": {"run": "run_5", "tag": "val_acc_global_model"}} # -# # 4.3 FedProx vs. FedOpt +# # 4.3 FedProx vs. FedOpt vs. SCAFFOLD # experiments = {"cifar10_fedavg": {"run": "run_5", "tag": "val_acc_global_model"}, # "cifar10_fedprox": {"run": "run_6", "tag": "val_acc_global_model"}, -# "cifar10_fedopt": {"run": "run_7", "tag": "val_acc_global_model"}} +# "cifar10_fedopt": {"run": "run_7", "tag": "val_acc_global_model"}, +# "cifar10_scaffold": {"run": "run_8", "tag": "val_acc_global_model"}} add_cross_site_val = True @@ -71,17 +73,11 @@ def add_eventdata(data, config, filepath, tag="val_acc_global_model"): def main(): - data = { - "Config": [], - "Step": [], - "Accuracy": [] - } + data = {"Config": [], "Step": [], "Accuracy": []} if add_cross_site_val: xsite_keys = ["SRV_server", "SRV_server_best"] - xsite_data = { - "Config": [] - } + xsite_data = {"Config": []} for k in xsite_keys: xsite_data.update({k: []}) else: @@ -97,7 +93,9 @@ def main(): add_eventdata(data, config, eventfile, tag=exp["tag"]) if add_cross_site_val: - xsite_file = glob.glob(os.path.join(server_results_root, exp["run"] + "/**/cross_site_val.json"), recursive=True) + xsite_file = glob.glob( + os.path.join(server_results_root, exp["run"] + "/**/cross_site_val.json"), recursive=True + ) assert len(xsite_file) == 1, "No unique x-site file found!" 
with open(xsite_file[0], "r") as f: xsite_results = json.load(f) diff --git a/examples/cifar10/pt/learners/cifar10_learner.py b/examples/cifar10/pt/learners/cifar10_learner.py index caa752c791..805391bea1 100644 --- a/examples/cifar10/pt/learners/cifar10_learner.py +++ b/examples/cifar10/pt/learners/cifar10_learner.py @@ -26,14 +26,15 @@ from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable from nvflare.apis.fl_constant import FLContextKey, ReturnCode from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply from nvflare.apis.signal import Signal from nvflare.app_common.abstract.learner_spec import Learner +from nvflare.app_common.abstract.model import ModelLearnableKey from nvflare.app_common.app_constant import AppConstants, ModelName, ValidateType from nvflare.app_common.pt.pt_fedproxloss import PTFedProxLoss -class CIFAR10Learner(Learner): +class CIFAR10Learner(Learner): # also supports CIFAR10ScaffoldLearner def __init__( self, dataset_root: str = "./dataset", @@ -52,6 +53,9 @@ def __init__( aggregation_epochs: the number of training epochs for a round. Defaults to 1. train_task_name: name of the task to train the model. submit_model_task_name: name of the task to submit the best local model. + lr: local learning rate. Float number. Defaults to 1e-2. + fedproxloss_mu: weight for FedProx loss. Float number. Defaults to 0.0 (no FedProx). + central: Bool. Whether to simulate central training. Default False. analytic_sender_id: id of `AnalyticsSender` if configured as a client component. If configured, TensorBoard events will be fired. Defaults to "analytic_sender". Returns: @@ -241,13 +245,10 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) - epoch_len = len(self.train_loader) self.log_info(fl_ctx, f"Local steps per epoch: {epoch_len}") - # make a copy of model_global as reference for potential FedProx loss - if self.fedproxloss_mu > 0: - model_global = copy.deepcopy(self.model) - for param in model_global.parameters(): - param.requires_grad = False - else: - model_global = None + # make a copy of model_global as reference for potential FedProx loss or SCAFFOLD + model_global = copy.deepcopy(self.model) + for param in model_global.parameters(): + param.requires_grad = False # local train self.local_train( @@ -334,8 +335,13 @@ def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal if abort_signal.triggered: return make_reply(ReturnCode.TASK_ABORTED) - # get round information + # get validation information self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}") + model_owner = shareable.get(ReservedHeaderKey.HEADERS).get(AppConstants.MODEL_OWNER) + if model_owner: + self.log_info(fl_ctx, f"Evaluating model from {model_owner} on {fl_ctx.get_identity_name()}") + else: + model_owner = "global_model" # evaluating global model during training # update local model weights with received weights dxo = from_shareable(shareable) @@ -344,15 +350,19 @@ def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation. 
local_var_dict = self.model.state_dict() model_keys = global_weights.keys() + n_loaded = 0 for var_name in local_var_dict: if var_name in model_keys: weights = torch.as_tensor(global_weights[var_name], device=self.device) try: # update the local dict local_var_dict[var_name] = torch.as_tensor(torch.reshape(weights, local_var_dict[var_name].shape)) + n_loaded += 1 except Exception as e: raise ValueError("Convert weight from {} failed with error: {}".format(var_name, str(e))) self.model.load_state_dict(local_var_dict) + if n_loaded == 0: + raise ValueError(f"No weights loaded for validation! Received weight dict is {global_weights}") validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE) if validate_type == ValidateType.BEFORE_TRAIN_VALIDATE: @@ -360,7 +370,7 @@ def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal global_acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_acc_global_model", fl_ctx=fl_ctx) if abort_signal.triggered: return make_reply(ReturnCode.TASK_ABORTED) - self.log_info(fl_ctx, f"val_acc_global_model: {global_acc:.4f}") + self.log_info(fl_ctx, f"val_acc_global_model ({model_owner}): {global_acc}") return DXO(data_kind=DataKind.METRICS, data={MetaKey.INITIAL_METRICS: global_acc}, meta={}).to_shareable() @@ -369,12 +379,12 @@ def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal train_acc = self.local_valid(self.train_loader, abort_signal) if abort_signal.triggered: return make_reply(ReturnCode.TASK_ABORTED) - self.log_info(fl_ctx, f"training acc: {train_acc:.4f}") + self.log_info(fl_ctx, f"training acc ({model_owner}): {train_acc}") val_acc = self.local_valid(self.valid_loader, abort_signal) if abort_signal.triggered: return make_reply(ReturnCode.TASK_ABORTED) - self.log_info(fl_ctx, f"validation acc: {val_acc:.4f}") + self.log_info(fl_ctx, f"validation acc ({model_owner}): {val_acc}") self.log_info(fl_ctx, "Evaluation finished. Returning shareable") diff --git a/examples/cifar10/pt/learners/cifar10_scaffold_learner.py b/examples/cifar10/pt/learners/cifar10_scaffold_learner.py new file mode 100644 index 0000000000..83a50e63d3 --- /dev/null +++ b/examples/cifar10/pt/learners/cifar10_scaffold_learner.py @@ -0,0 +1,187 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy + +import torch +from pt.learners.cifar10_learner import CIFAR10Learner + +from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.app_constant import AlgorithmConstants, AppConstants +from nvflare.app_common.pt.pt_scaffold import PTScaffoldHelper, get_lr_values + + +class CIFAR10ScaffoldLearner(CIFAR10Learner): + def __init__( + self, + dataset_root: str = "./dataset", + aggregation_epochs: int = 1, + train_task_name: str = AppConstants.TASK_TRAIN, + submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL, + lr: float = 1e-2, + fedproxloss_mu: float = 0.0, + central: bool = False, + analytic_sender_id: str = "analytic_sender", + ): + """Simple Scaffold CIFAR-10 Trainer. + Implements the training algorithm proposed in + Karimireddy et al. "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning" + (https://arxiv.org/abs/1910.06378) using functions implemented in `PTScaffoldHelper` class. + + Args: + dataset_root: directory with CIFAR-10 data. + aggregation_epochs: the number of training epochs for a round. Defaults to 1. + train_task_name: name of the task to train the model. + submit_model_task_name: name of the task to submit the best local model. + lr: local learning rate. Float number. Defaults to 1e-2. + fedproxloss_mu: weight for FedProx loss. Float number. Defaults to 0.0 (no FedProx). + central: Bool. Whether to simulate central training. Default False. + analytic_sender_id: id of `AnalyticsSender` if configured as a client component. If configured, TensorBoard events will be fired. Defaults to "analytic_sender". + + Returns: + a Shareable with the updated local model after running `execute()` + or the best local model depending on the specified task. 
+ """ + + CIFAR10Learner.__init__( + self, + dataset_root=dataset_root, + aggregation_epochs=aggregation_epochs, + train_task_name=train_task_name, + submit_model_task_name=submit_model_task_name, + lr=lr, + fedproxloss_mu=fedproxloss_mu, + central=central, + analytic_sender_id=analytic_sender_id, + ) + self.scaffold_helper = PTScaffoldHelper() + + def initialize(self, parts: dict, fl_ctx: FLContext): + # Initialize super class and SCAFFOLD + CIFAR10Learner.initialize(self, parts=parts, fl_ctx=fl_ctx) + self.scaffold_helper.init(model=self.model) + + def local_train(self, fl_ctx, train_loader, model_global, abort_signal: Signal, val_freq: int = 0): + # local_train with SCAFFOLD steps + c_global_para, c_local_para = self.scaffold_helper.get_params() + for epoch in range(self.aggregation_epochs): + if abort_signal.triggered: + return + self.model.train() + epoch_len = len(train_loader) + self.epoch_global = self.epoch_of_start_time + epoch + self.log_info(fl_ctx, f"Local epoch {self.client_id}: {epoch + 1}/{self.aggregation_epochs} (lr={self.lr})") + + for i, (inputs, labels) in enumerate(train_loader): + if abort_signal.triggered: + return + inputs, labels = inputs.to(self.device), labels.to(self.device) + # zero the parameter gradients + self.optimizer.zero_grad() + # forward + backward + optimize + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + + # FedProx loss term + if self.fedproxloss_mu > 0: + fed_prox_loss = self.criterion_prox(self.model, model_global) + loss += fed_prox_loss + + loss.backward() + self.optimizer.step() + + # SCAFFOLD step + curr_lr = get_lr_values(self.optimizer)[0] + self.scaffold_helper.model_update( + model=self.model, curr_lr=curr_lr, c_global_para=c_global_para, c_local_para=c_local_para + ) + + current_step = epoch_len * self.epoch_global + i + self.writer.add_scalar("train_loss", loss.item(), current_step) + + if val_freq > 0 and epoch % val_freq == 0: + acc = self.local_valid(self.valid_loader, abort_signal, tb_id="val_acc_local_model", fl_ctx=fl_ctx) + if acc > self.best_acc: + self.save_model(is_best=True) + + # Update the SCAFFOLD terms + self.scaffold_helper.terms_update( + model=self.model, + curr_lr=curr_lr, + c_global_para=c_global_para, + c_local_para=c_local_para, + model_global=model_global, + ) + + def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + # return DXO with extra control differences for SCAFFOLD + dxo_collection = from_shareable(shareable) + if dxo_collection.data_kind != DataKind.COLLECTION: + self.log_error( + fl_ctx, + f"SCAFFOLD learner expected shareable to contain a collection of two DXOs " + f"but got data kind {dxo_collection.data_kind}.", + ) + return make_reply(ReturnCode.ERROR) + dxo_global_weights = dxo_collection.data.get(AppConstants.MODEL_WEIGHTS) + dxo_global_ctrl_weights = dxo_collection.data.get(AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL) + if dxo_global_ctrl_weights is None: + self.log_error(fl_ctx, "DXO collection doesn't contain the SCAFFOLD controls!") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + # convert to tensor and load into c_global model + global_ctrl_weights = dxo_global_ctrl_weights.data + for k in global_ctrl_weights.keys(): + global_ctrl_weights[k] = torch.as_tensor(global_ctrl_weights[k]) + self.scaffold_helper.load_global_controls(weights=global_ctrl_weights) + + # modify shareable to only contain global weights + shareable = dxo_global_weights.update_shareable(shareable) # TODO: add set_dxo() method to Shareable + + # local 
training
+        result_shareable = super().train(shareable, fl_ctx, abort_signal)
+        if result_shareable.get_return_code() == ReturnCode.OK:
+            # get DXO with weight updates from local training
+            dxo_weights_diff = from_shareable(result_shareable)
+
+            # Create a DXO collection with weights and scaffold controls
+            dxo_weights_diff_ctrl = DXO(data_kind=DataKind.WEIGHT_DIFF, data=self.scaffold_helper.get_delta_controls())
+            # add same num steps as for model weights
+            dxo_weights_diff_ctrl.set_meta_prop(
+                MetaKey.NUM_STEPS_CURRENT_ROUND, dxo_weights_diff.get_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND)
+            )
+
+            collection_data = {
+                AppConstants.MODEL_WEIGHTS: dxo_weights_diff,
+                AlgorithmConstants.SCAFFOLD_CTRL_DIFF: dxo_weights_diff_ctrl,
+            }
+            dxo = DXO(data_kind=DataKind.COLLECTION, data=collection_data)
+
+            return dxo.to_shareable()
+        else:
+            return result_shareable
+
+    def validate(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
+        dxo = from_shareable(shareable)
+
+        # If collection, extract only global weights to validate
+        if dxo.data_kind == DataKind.COLLECTION:
+            # create a new shareable with only model weights
+            shareable = copy.deepcopy(shareable)  # TODO: Is this the best way?
+            dxo_global_weights = dxo.data.get(AppConstants.MODEL_WEIGHTS)
+            shareable = dxo_global_weights.update_shareable(shareable)  # TODO: add set_dxo() method to Shareable
+
+        return super().validate(shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal)
diff --git a/examples/cifar10/pt/pt_formatter.py b/examples/cifar10/pt/pt_formatter.py
deleted file mode 100755
index ba862d590d..0000000000
--- a/examples/cifar10/pt/pt_formatter.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) 2021, NVIDIA CORPORATION.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from nvflare.apis.dxo import DataKind, from_bytes
-from nvflare.apis.fl_context import FLContext
-from nvflare.app_common.abstract.formatter import Formatter
-from nvflare.app_common.app_constant import AppConstants
-
-
-class PTFormatter(Formatter):
-    def __init__(self) -> None:
-        super().__init__()
-        self.results = {}
-
-    def format(self, fl_ctx: FLContext) -> str:
-        """The format function gets validation shareable locations from the dictionary. It loads each shareable,
-        get the validation results and converts it into human readable string.
-
-        Args:
-            fl_ctx (FLContext): FLContext object.
-
-        Returns:
-            str: Human readable validation results.
- """ - # Get the val shareables - validation_shareables_dict = fl_ctx.get_prop(AppConstants.VALIDATION_RESULT, {}) - - # Result dictionary - res = {} - - try: - # This is a 2d dictionary with each validation result at - # validation_shareables_dict[data_client][model_client] - for data_client in validation_shareables_dict.keys(): - validation_dict = validation_shareables_dict[data_client] - if validation_dict: - res[data_client] = {} - for model_name in validation_dict.keys(): - dxo_path = validation_dict[model_name] - - # Load the shareable - with open(dxo_path, "rb") as f: - metric_dxo = from_bytes(f.read()) - - # Get metrics from shareable - if metric_dxo and metric_dxo.data_kind == DataKind.METRICS: - metrics = metric_dxo.data - res[data_client][model_name] = metrics - # add any results - print(f"Updating results {res}") - self.results.update(res) - print(f"Updating results {self.results}") - except Exception as e: - self.log_error(fl_ctx, f"Exception: {e.__str__()}") - - return f"{res}" diff --git a/examples/cifar10/pt/pt_model_locator.py b/examples/cifar10/pt/pt_model_locator.py deleted file mode 100755 index 0640cbbc91..0000000000 --- a/examples/cifar10/pt/pt_model_locator.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import List - -import torch - -from nvflare.apis.dxo import DXO, DataKind -from nvflare.apis.fl_constant import FLContextKey -from nvflare.apis.fl_context import FLContext -from nvflare.app_common.abstract.model_locator import ModelLocator -from nvflare.app_common.pt.pt_fed_utils import PTModelPersistenceFormatManager -from nvflare.app_common.app_constant import DefaultCheckpointFileName - - -class PTModelLocator(ModelLocator): - SERVER_MODEL_NAME = "server" - SERVER_BEST_MODEL_NAME = "server_best" - - def __init__( - self, model_dir="app_server", - model_name=DefaultCheckpointFileName.GLOBAL_MODEL, - best_model_name=DefaultCheckpointFileName.BEST_GLOBAL_MODEL - ): - """A ModelLocator that provides the global and best global models. - - Args: - model_dir: directory where global models are saved. - model_name: name of the saved global model. - best_model_name: name of the saved best global model. - - Returns: - a DXO depending on the specified `model_name` in `locate_model()`. - """ - super().__init__() - - self.model_dir = model_dir - self.model_file_name = model_name - self.best_model_file_name = best_model_name - - def get_model_names(self, fl_ctx: FLContext) -> List[str]: - """Returns the list of model names that should be included from server in cross site validation.add() - - Args: - fl_ctx (FLContext): FL Context object. - - Returns: - List[str]: List of model names. 
- """ - return [PTModelLocator.SERVER_MODEL_NAME, PTModelLocator.SERVER_BEST_MODEL_NAME] - - def locate_model(self, model_name, fl_ctx: FLContext) -> DXO: - dxo = None - engine = fl_ctx.get_engine() - - if model_name in (PTModelLocator.SERVER_MODEL_NAME, PTModelLocator.SERVER_BEST_MODEL_NAME): - run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) - run_dir = engine.get_workspace().get_run_dir(run_number) - model_path = os.path.join(run_dir, self.model_dir) - - if model_name == PTModelLocator.SERVER_BEST_MODEL_NAME: - model_load_path = os.path.join(model_path, self.best_model_file_name) - else: - model_load_path = os.path.join(model_path, self.model_file_name) - model_data = None - try: - model_data = torch.load(model_load_path) - self.log_info(fl_ctx, f"Loaded {model_name} model from {model_load_path}.") - except Exception as e: - self.log_error(fl_ctx, f"Unable to load model: {e}.") - - if model_data is not None: - mgr = PTModelPersistenceFormatManager(model_data) - dxo = DXO(data_kind=DataKind.WEIGHTS, data=mgr.var_dict, meta=mgr.meta) - - return dxo diff --git a/examples/cifar10/pt/validation_json_generator.py b/examples/cifar10/pt/validation_json_generator.py deleted file mode 100644 index 3a9945bc65..0000000000 --- a/examples/cifar10/pt/validation_json_generator.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os.path - -from nvflare.apis.dxo import DataKind, from_shareable -from nvflare.apis.event_type import EventType -from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_context import FLContext -from nvflare.app_common.app_constant import AppConstants -from nvflare.app_common.app_event_type import AppEventType - - -class ValidationJsonGenerator(FLComponent): - def __init__(self, results_dir=AppConstants.CROSS_VAL_DIR, json_file_name="cross_site_val.json"): - """A class to generate a json file with cross-site validation results. - - Args: - results_dir: directory where to save the json with cross-site validation results. - json_file_name: filename of the json to be saved. - - """ - super(ValidationJsonGenerator, self).__init__() - - self.results_dir = results_dir - self.val_results = {} - self.json_file_name = json_file_name - - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - self.val_results.clear() - elif event_type == AppEventType.VALIDATION_RESULT_RECEIVED: - model_owner = fl_ctx.get_prop(AppConstants.MODEL_OWNER, None) - data_client = fl_ctx.get_prop(AppConstants.DATA_CLIENT, None) - val_results = fl_ctx.get_prop(AppConstants.VALIDATION_RESULT, None) - - if not model_owner: - self.log_error(fl_ctx, "model_owner unknown. Validation result will not be saved to json") - if not data_client: - self.log_error(fl_ctx, "data_client unknown. 
Validation result will not be saved to json") - - if val_results: - try: - dxo = from_shareable(val_results) - dxo.validate() - - if dxo.data_kind == DataKind.METRICS: - if data_client not in self.val_results: - self.val_results[data_client] = {} - self.val_results[data_client][model_owner] = dxo.data - else: - self.log_error(fl_ctx, f"Expected dxo of kind METRICS but got {dxo.data_kind} instead.") - except: - self.log_exception(fl_ctx, "Exception in handling validation result.") - else: - self.log_error(fl_ctx, "Validation result not found.") - elif event_type == EventType.END_RUN: - run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_run_number()) - cross_val_res_dir = os.path.join(run_dir, self.results_dir) - if not os.path.exists(cross_val_res_dir): - os.makedirs(cross_val_res_dir) - - res_file_path = os.path.join(cross_val_res_dir, self.json_file_name) - with open(res_file_path, "w") as f: - json.dump(self.val_results, f, indent=4) diff --git a/examples/cifar10/run_experiments.sh b/examples/cifar10/run_experiments.sh index 85e0f04237..cd29a817a9 100755 --- a/examples/cifar10/run_experiments.sh +++ b/examples/cifar10/run_experiments.sh @@ -15,8 +15,11 @@ # FedOpt ./run_secure.sh 8 cifar10_fedopt 7 0.1 +# SCAFFOLD +./run_secure.sh 8 cifar10_scaffold 8 0.1 + # FedAvg + HE -./run_secure.sh 8 cifar10_fedavg_he 8 1.0 +./run_secure.sh 8 cifar10_fedavg_he 9 1.0 # FedAvg with TensorBoard streaming -./run_secure.sh 8 cifar10_fedavg_stream_tb 9 1.0 +./run_secure.sh 8 cifar10_fedavg_stream_tb 10 1.0 diff --git a/nvflare/app_common/app_constant.py b/nvflare/app_common/app_constant.py index b18c42d662..1adaa458c1 100644 --- a/nvflare/app_common/app_constant.py +++ b/nvflare/app_common/app_constant.py @@ -149,3 +149,10 @@ class ValidateType(object): BEFORE_TRAIN_VALIDATE = "before_train_validate" MODEL_VALIDATE = "model_validate" + + +class AlgorithmConstants(object): + + SCAFFOLD_CTRL_DIFF = "scaffold_c_diff" + SCAFFOLD_CTRL_GLOBAL = "scaffold_c_global" + SCAFFOLD_CTRL_AGGREGATOR_ID = "scaffold_ctrl_aggregator" diff --git a/nvflare/app_common/handlers/intime_model_selection_handler.py b/nvflare/app_common/handlers/intime_model_selection_handler.py index 9334dd3dae..f1e4508b91 100644 --- a/nvflare/app_common/handlers/intime_model_selection_handler.py +++ b/nvflare/app_common/handlers/intime_model_selection_handler.py @@ -72,7 +72,7 @@ def _before_accept(self, fl_ctx: FLContext): self.log_exception(fl_ctx, "shareable data is not a valid DXO") return False - if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS): + if dxo.data_kind not in (DataKind.WEIGHT_DIFF, DataKind.WEIGHTS, DataKind.COLLECTION): self.log_debug(fl_ctx, "I cannot handle {}".format(dxo.data_kind)) return False diff --git a/nvflare/app_common/pt/pt_scaffold.py b/nvflare/app_common/pt/pt_scaffold.py new file mode 100644 index 0000000000..e4ac36f340 --- /dev/null +++ b/nvflare/app_common/pt/pt_scaffold.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# The SCAFFOLD-related functions are based on https://github.com/Xtra-Computing/NIID-Bench + +# MIT License +# +# Copyright (c) 2021 Yiqun Diao, Qinbin Li +# +# Copyright (c) 2020 International Business Machines +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import copy + +import torch +from torch.optim import Optimizer + + +def get_lr_values(optimizer: Optimizer): + """ + This function is used to get the learning rates of the optimizer. + """ + return [group["lr"] for group in optimizer.state_dict()["param_groups"]] + + +class PTScaffoldHelper(object): + """Helper to be used with SCAFFOLD components. + Implements the functions used for the algorithm proposed in + Karimireddy et al. "SCAFFOLD: Stochastic Controlled Averaging for Federated Learning" + (https://arxiv.org/abs/1910.06378) using PyTorch. + SCAFFOLD-related functions are based on https://github.com/Xtra-Computing/NIID-Bench. + See also Li et al. "Federated Learning on Non-IID Data Silos: An Experimental Study" + (https://arxiv.org/abs/2102.02079). 
+ """ + + def __init__(self): + # SCAFFOLD control terms + self.cnt = 0 + self.c_global = None + self.c_local = None + self.c_delta_para = None + + def init(self, model): + # create models for SCAFFOLD correction terms + self.c_global = copy.deepcopy(model) + self.c_local = copy.deepcopy(model) + # Initialize correction term with zeros + c_init_para = model.state_dict() + for k in c_init_para.keys(): + c_init_para[k] = torch.zeros_like(c_init_para[k]) + self.c_global.load_state_dict(c_init_para) + self.c_local.load_state_dict(c_init_para) + + def get_params(self): + self.cnt = 0 + # Adapted from https://github.com/Xtra-Computing/NIID-Bench/blob/main/experiments.py#L371 + c_global_para = self.c_global.state_dict() + c_local_para = self.c_local.state_dict() + return c_global_para, c_local_para + + def model_update(self, model, curr_lr, c_global_para, c_local_para): + # Update model using scaffold controls + # See https://github.com/Xtra-Computing/NIID-Bench/blob/main/experiments.py#L391 + net_para = model.state_dict() + for key in net_para: + net_para[key] = net_para[key] - curr_lr * (c_global_para[key] - c_local_para[key]) + model.load_state_dict(net_para) + + self.cnt += 1 + + def terms_update(self, model, curr_lr, c_global_para, c_local_para, model_global): + # Update the local scaffold controls + # See https://github.com/Xtra-Computing/NIID-Bench/blob/main/experiments.py#L403 + + c_new_para = self.c_local.state_dict() + self.c_delta_para = copy.deepcopy(self.c_local.state_dict()) + global_model_para = model_global.state_dict() + net_para = model.state_dict() + for key in net_para: + c_new_para[key] = ( + c_new_para[key] - c_global_para[key] + (global_model_para[key] - net_para[key]) / (self.cnt * curr_lr) + ) + self.c_delta_para[key] = (c_new_para[key] - c_local_para[key]).cpu().numpy() + self.c_local.load_state_dict(c_new_para) + + def load_global_controls(self, weights): + self.c_global.load_state_dict(weights) + + def get_delta_controls(self): + if self.c_delta_para is None: + raise ValueError("c_delta_para hasn't been computed yet!") + return self.c_delta_para diff --git a/nvflare/app_common/workflows/scatter_and_gather_scaffold.py b/nvflare/app_common/workflows/scatter_and_gather_scaffold.py new file mode 100644 index 0000000000..74b53ed363 --- /dev/null +++ b/nvflare/app_common/workflows/scatter_and_gather_scaffold.py @@ -0,0 +1,203 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy
+import traceback
+
+import numpy as np
+
+from nvflare.apis.dxo import DXO, DataKind, from_shareable
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.impl.controller import Task
+from nvflare.apis.signal import Signal
+from nvflare.app_common.abstract.model import model_learnable_to_dxo
+from nvflare.app_common.app_constant import AlgorithmConstants, AppConstants
+from nvflare.app_common.app_event_type import AppEventType
+from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
+
+
+class ScatterAndGatherScaffold(ScatterAndGather):
+    def __init__(
+        self,
+        min_clients: int = 1,
+        num_rounds: int = 5,
+        start_round: int = 0,
+        wait_time_after_min_received: int = 10,
+        aggregator_id=AppConstants.DEFAULT_AGGREGATOR_ID,
+        persistor_id=AppConstants.DEFAULT_PERSISTOR_ID,
+        shareable_generator_id=AppConstants.DEFAULT_SHAREABLE_GENERATOR_ID,
+        train_task_name=AppConstants.TASK_TRAIN,
+        train_timeout: int = 0,
+        ignore_result_error: bool = True,
+    ):
+        """SCAFFOLD workflow. The ScatterAndGatherScaffold workflow defines federated training with the SCAFFOLD algorithm on all clients.
+        The model persistor (persistor_id) is used to load the initial global model which is sent to all clients.
+        Each client sends its updated weights after local training, which are aggregated (aggregator_id). The
+        shareable generator is used to convert the aggregated weights to a shareable, and the shareable back to weights.
+        The model_persistor also saves the model after training.
+
+        Args:
+            min_clients (int, optional): Min number of clients in training. Defaults to 1.
+            num_rounds (int, optional): The total number of training rounds. Defaults to 5.
+            start_round (int, optional): Start round for training. Defaults to 0.
+            wait_time_after_min_received (int, optional): Time to wait before beginning aggregation after
+                contributions received. Defaults to 10.
+            train_timeout (int, optional): Time to wait for clients to do local training.
+            aggregator_id (str, optional): ID of the aggregator component. Defaults to "aggregator".
+            persistor_id (str, optional): ID of the persistor component. Defaults to "persistor".
+            shareable_generator_id (str, optional): ID of the shareable generator. Defaults to "shareable_generator".
+            train_task_name (str, optional): Name of the train task. Defaults to "train".
+ """ + + super().__init__( + min_clients=min_clients, + num_rounds=num_rounds, + start_round=start_round, + wait_time_after_min_received=wait_time_after_min_received, + aggregator_id=aggregator_id, + persistor_id=persistor_id, + shareable_generator_id=shareable_generator_id, + train_task_name=train_task_name, + train_timeout=train_timeout, + ignore_result_error=ignore_result_error, + ) + + # for SCAFFOLD + self.aggregator_ctrl = None + self._global_ctrl_weights = None + + def start_controller(self, fl_ctx: FLContext) -> None: + super().start_controller(fl_ctx=fl_ctx) + self.log_info(fl_ctx, "Initializing ScatterAndGatherScaffold workflow.") + + # for SCAFFOLD + if not self._global_weights: + self.system_panic("Global weights not available!", fl_ctx) + return + + self._global_ctrl_weights = copy.deepcopy(self._global_weights["weights"]) + # Initialize correction term with zeros + for k in self._global_ctrl_weights.keys(): + self._global_ctrl_weights[k] = np.zeros_like(self._global_ctrl_weights[k]) + # TODO: Print some stats of the correction magnitudes + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext) -> None: + try: + + self.log_info(fl_ctx, "Beginning ScatterAndGatherScaffold training phase.") + self._phase = AppConstants.PHASE_TRAIN + + fl_ctx.set_prop(AppConstants.PHASE, self._phase, private=True, sticky=False) + fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=False) + self.fire_event(AppEventType.TRAINING_STARTED, fl_ctx) + + for self._current_round in range(self._start_round, self._start_round + self._num_rounds): + if self._check_abort_signal(fl_ctx, abort_signal): + return + + self.log_info(fl_ctx, f"Round {self._current_round} started.") + fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True) + fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=False) + self.fire_event(AppEventType.ROUND_STARTED, fl_ctx) + + # Create train_task + # get DXO with global model weights + dxo_global_weights = model_learnable_to_dxo(self._global_weights) + + # add global SCAFFOLD controls using a DXO collection + dxo_global_ctrl = DXO(data_kind=DataKind.WEIGHT_DIFF, data=self._global_ctrl_weights) + dxo_dict = { + AppConstants.MODEL_WEIGHTS: dxo_global_weights, + AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL: dxo_global_ctrl, + } + dxo_collection = DXO(data_kind=DataKind.COLLECTION, data=dxo_dict) + data_shareable = dxo_collection.to_shareable() + + # add meta information + data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round) + data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds) + data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round) + + train_task = Task( + name=self.train_task_name, + data=data_shareable, + props={}, + timeout=self._train_timeout, + before_task_sent_cb=self._prepare_train_task_data, + result_received_cb=self._process_train_result, + ) + + self.broadcast_and_wait( + task=train_task, + min_responses=self._min_clients, + wait_time_after_min_received=self._wait_time_after_min_received, + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + if self._check_abort_signal(fl_ctx, abort_signal): + return + + self.fire_event(AppEventType.BEFORE_AGGREGATION, fl_ctx) + aggr_result = self.aggregator.aggregate(fl_ctx) + + # extract aggregated weights and controls + collection_dxo = from_shareable(aggr_result) + dxo_aggr_result = collection_dxo.data.get(AppConstants.MODEL_WEIGHTS) + if not dxo_aggr_result: + 
self.log_error(fl_ctx, "Aggregated model weights are missing!") + return + dxo_ctrl_aggr_result = collection_dxo.data.get(AlgorithmConstants.SCAFFOLD_CTRL_DIFF) + if not dxo_ctrl_aggr_result: + self.log_error(fl_ctx, "Aggregated model weight controls are missing!") + return + + fl_ctx.set_prop(AppConstants.AGGREGATION_RESULT, aggr_result, private=True, sticky=False) + self.fire_event(AppEventType.AFTER_AGGREGATION, fl_ctx) + + if self._check_abort_signal(fl_ctx, abort_signal): + return + + # update global model using shareable generator + self.fire_event(AppEventType.BEFORE_SHAREABLE_TO_LEARNABLE, fl_ctx) + self._global_weights = self.shareable_gen.shareable_to_learnable(dxo_aggr_result.to_shareable(), fl_ctx) + + # update SCAFFOLD global controls + ctr_diff = dxo_ctrl_aggr_result.data + for v_name, v_value in ctr_diff.items(): + self._global_ctrl_weights[v_name] += v_value + fl_ctx.set_prop( + AlgorithmConstants.SCAFFOLD_CTRL_GLOBAL, self._global_ctrl_weights, private=True, sticky=True + ) + + fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self._global_weights, private=True, sticky=True) + fl_ctx.sync_sticky() + self.fire_event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE, fl_ctx) + + if self._check_abort_signal(fl_ctx, abort_signal): + return + + self.fire_event(AppEventType.BEFORE_LEARNABLE_PERSIST, fl_ctx) + self.persistor.save(self._global_weights, fl_ctx) + self.fire_event(AppEventType.AFTER_LEARNABLE_PERSIST, fl_ctx) + + self.fire_event(AppEventType.ROUND_DONE, fl_ctx) + self.log_info(fl_ctx, f"Round {self._current_round} finished.") + + self._phase = AppConstants.PHASE_FINISHED + self.log_info(fl_ctx, "Finished ScatterAndGatherScaffold Training.") + except BaseException as e: + traceback.print_exc() + error_msg = f"Exception in ScatterAndGatherScaffold control_flow: {e}" + self.log_exception(fl_ctx, error_msg) + self.system_panic(str(e), fl_ctx)
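
For reviewers, the client-side call sequence into the new `PTScaffoldHelper` (as implemented by `CIFAR10ScaffoldLearner` above) boils down to the following skeleton. This is a sketch with toy stand-ins for the model, optimizer, and data; in the example these come from the learner itself:

```python
import torch
from torch import nn

from nvflare.app_common.pt.pt_scaffold import PTScaffoldHelper, get_lr_values

# toy stand-ins for what the learner sets up
model = nn.Linear(10, 2)
model_global = nn.Linear(10, 2)
model_global.load_state_dict(model.state_dict())
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
inputs, labels = torch.randn(4, 10), torch.randint(0, 2, (4,))

helper = PTScaffoldHelper()
helper.init(model=model)  # c_global and c_local start as zeros

# once per round: fetch the control terms, then train locally
c_global_para, c_local_para = helper.get_params()
for _ in range(3):  # a few local steps
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()
    # SCAFFOLD correction applied after every optimizer step
    curr_lr = get_lr_values(optimizer)[0]
    helper.model_update(
        model=model, curr_lr=curr_lr, c_global_para=c_global_para, c_local_para=c_local_para
    )

# at the end of local training, update the local control variate ...
helper.terms_update(
    model=model,
    curr_lr=curr_lr,
    c_global_para=c_global_para,
    c_local_para=c_local_para,
    model_global=model_global,
)
# ... and read out the control deltas that go back to the server
c_delta = helper.get_delta_controls()
```

The final delta is what `CIFAR10ScaffoldLearner.train()` packages into the `scaffold_c_diff` entry of the returned DXO collection, alongside the regular model weight diff.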