Cifar10 scaffold (#127)
* SCAFFOLD Integration

scaffold learner

scaffold configs

initialize the scaffold terms on the client

update README to include SCAFFOLD description

add scaffold experiment

update urls

add LICENSE from NIID-Bench

update README and plots to reflect SCAFFOLD experiment

refactor scaffold computation steps into their own member functions

remove unused import

refactor to use ScaffoldLearner class

Scaffold learner depends on PyTorch. Renamed as such

use two standard aggregators, inherit SCAFFOLD workflow from standard ScatterAndGather

remove custom aggregator

use update_shareable

use multi-class inheritance for scaffold learner

use new aggregator that supports several DXOs

fix updating call for global model

simplify dxo/shareable handling

use built-in PTFileModelLocator

fix formatting

add PT Formatter

use built-in validation json generator; remove custom formatter

remove custom Json validation generator

update license

restore run_secure.sh

restore project yml

restore main branch learner executor

add SCAFFOLD link to example readme

remove custom validation json generator

print model owner info during validation

remove custom formatter

remove special handling of validation dxo. Not needed anymore

use zeros_like() to initialize scaffold terms

* use scaffold helper class

* formatting

* make scaffold function args consistent
holgerroth authored Feb 16, 2022
1 parent b7e98b7 commit 7ab7b1e
Showing 26 changed files with 741 additions and 345 deletions.
2 changes: 1 addition & 1 deletion examples/README.md
@@ -22,7 +22,7 @@ The provided examples cover different aspects of [NVIDIA FLARE](https://nvidia.g

## 2. FL algorithms
* [Federated Learning with CIFAR-10](./cifar10/README.md)
- * Includes examples of using [FedAvg](https://arxiv.org/abs/1602.05629), [FedProx](https://arxiv.org/abs/1812.06127), [FedOpt](https://arxiv.org/abs/2003.00295), [homomorphic encryption](https://developer.nvidia.com/blog/federated-learning-with-homomorphic-encryption/), and streaming of TensorBoard metrics to the server during training.
+ * Includes examples of using [FedAvg](https://arxiv.org/abs/1602.05629), [FedProx](https://arxiv.org/abs/1812.06127), [FedOpt](https://arxiv.org/abs/2003.00295), [SCAFFOLD](https://arxiv.org/abs/1910.06378), [homomorphic encryption](https://developer.nvidia.com/blog/federated-learning-with-homomorphic-encryption/), and streaming of TensorBoard metrics to the server during training.

## 3. Medical Image Analysis
* [Hello MONAI](./hello-monai/README.md)
46 changes: 25 additions & 21 deletions examples/cifar10/README.md
@@ -123,19 +123,23 @@ for example
cat ./workspaces/poc_workspace/server/run_2/cross_site_val/cross_site_val.json
```

- ### 3.4: Advanced FL algorithms (FedProx and FedOpt)
+ ### 3.4: Advanced FL algorithms (FedProx, FedOpt, and SCAFFOLD)

Next, let's try some different FL algorithms on a more heterogeneous split:

- FedProx (https://arxiv.org/abs/1812.06127), adds a regularizer to the CIFAR10Trainer loss (`fedproxloss_mu`)`:
+ [FedProx](https://arxiv.org/abs/1812.06127) adds a regularizer to the loss used in `CIFAR10Learner` (`fedproxloss_mu`):
```
./run_poc.sh 8 cifar10_fedprox 6 0.1
```
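For orientation, the proximal term that `fedproxloss_mu` scales amounts to a few lines of PyTorch. The sketch below illustrates the FedProx idea only; it is not the learner code used in this example, and `model`, `global_model`, `criterion`, `inputs`, and `targets` are assumed placeholders.
```python
import torch

def fedprox_penalty(model: torch.nn.Module,
                    global_model: torch.nn.Module,
                    mu: float) -> torch.Tensor:
    """FedProx proximal term: mu/2 * ||w - w_global||^2 summed over all parameters."""
    device = next(model.parameters()).device
    penalty = torch.zeros(1, device=device)
    for w, w_global in zip(model.parameters(), global_model.parameters()):
        penalty = penalty + torch.sum((w - w_global.detach()) ** 2)
    return 0.5 * mu * penalty

# Inside the local training loop the penalty is simply added to the task loss:
# loss = criterion(model(inputs), targets) + fedprox_penalty(model, global_model, mu=fedproxloss_mu)
```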
- FedOpt (https://arxiv.org/abs/2003.00295), uses a new ShareableGenerator to update the global model on the server using a PyTorch optimizer.
+ [FedOpt](https://arxiv.org/abs/2003.00295) uses a new ShareableGenerator to update the global model on the server using a PyTorch optimizer.
Here SGD with momentum and cosine learning rate decay:
```
./run_poc.sh 8 cifar10_fedopt 7 0.1
```
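To illustrate the server-side update, FedOpt treats the difference between the current global weights and the aggregated client result as a pseudo-gradient and feeds it to a regular PyTorch optimizer. The sketch below assumes plain SGD with momentum plus `CosineAnnealingLR`; it is illustrative and not the ShareableGenerator implementation itself.
```python
import torch

def fedopt_server_step(global_model: torch.nn.Module,
                       aggregated_model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer,
                       scheduler: torch.optim.lr_scheduler.CosineAnnealingLR) -> None:
    """One FL round on the server: apply the pseudo-gradient (global - aggregated) with SGD momentum."""
    optimizer.zero_grad()
    for w_global, w_agg in zip(global_model.parameters(), aggregated_model.parameters()):
        w_global.grad = w_global.data - w_agg.data  # pseudo-gradient for this round
    optimizer.step()   # momentum update of the global weights
    scheduler.step()   # cosine learning-rate decay across rounds

# Example wiring (assumed hyperparameters):
# optimizer = torch.optim.SGD(global_model.parameters(), lr=1.0, momentum=0.9)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_rounds)
```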
+ [SCAFFOLD](https://arxiv.org/abs/1910.06378) uses a slightly modified version of the CIFAR-10 Learner implementation, namely the `CIFAR10ScaffoldLearner`, which adds a correction term during local training, following the [implementation](https://github.com/Xtra-Computing/NIID-Bench) as described in [Li et al.](https://arxiv.org/abs/2102.02079):
+ ```
+ ./run_poc.sh 8 cifar10_scaffold 8 0.1
+ ```
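The SCAFFOLD correction can be pictured as an extra nudge applied after every optimizer step, steering the local update by the difference between the server and client control variates. The sketch below is illustrative, assuming the control variates are kept as dictionaries of tensors keyed by parameter name (initialized with `torch.zeros_like`); it is not the `CIFAR10ScaffoldLearner` code itself.
```python
import torch

@torch.no_grad()
def scaffold_correct(model: torch.nn.Module,
                     c_global: dict,
                     c_local: dict,
                     lr: float) -> None:
    """After optimizer.step(), shift the weights by -lr * (c_global - c_local)."""
    for name, param in model.named_parameters():
        param.data.add_(c_global[name] - c_local[name], alpha=-lr)

# Per local training step (sketch):
# loss.backward(); optimizer.step()
# scaffold_correct(model, c_global, c_local, lr)
```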

### 3.5 Secure aggregation using homomorphic encryption

@@ -147,7 +151,7 @@ Next we run FedAvg using homomorphic encryption (HE) for secure aggregation on t
FedAvg with HE:
```
- ./run_secure.sh 8 cifar10_fedavg_he 8 1.0
+ ./run_secure.sh 8 cifar10_fedavg_he 9 1.0
```

> **_NOTE:_** Currently, FedOpt is not supported with HE as it would involve running the optimizer on encrypted values.
@@ -173,8 +177,8 @@ that HE does not impact the performance accuracy of FedAvg significantly while a
| Config | Alpha | Val score |
| ----------- | ----------- | ----------- |
| cifar10_central | 1.0 | 0.8798 |
- | cifar10_fedavg | 1.0 | 0.8873 |
- | cifar10_fedavg_he | 1.0 | 0.8864 |
+ | cifar10_fedavg | 1.0 | 0.8854 |
+ | cifar10_fedavg_he | 1.0 | 0.8897 |

![Central vs. FedAvg](./figs/central_vs_fedavg_he.png)

@@ -185,33 +189,33 @@ This can be observed in the resulting performance of the FedAvg algorithms.

| Config | Alpha | Val score |
| ----------- | ----------- | ----------- |
- | cifar10_fedavg | 1.0 | 0.8873 |
- | cifar10_fedavg | 0.5 | 0.8726 |
- | cifar10_fedavg | 0.3 | 0.8315 |
- | cifar10_fedavg | 0.1 | 0.7726 |
+ | cifar10_fedavg | 1.0 | 0.8854 |
+ | cifar10_fedavg | 0.5 | 0.8633 |
+ | cifar10_fedavg | 0.3 | 0.8350 |
+ | cifar10_fedavg | 0.1 | 0.7733 |

![Impact of client data heterogeneity](./figs/fedavg_alpha.png)

- ### 4.3 FedProx vs. FedOpt
+ ### 4.3 FedAvg vs. FedProx vs. FedOpt vs. SCAFFOLD

- Finally, we are comparing an `alpha` setting of 0.1, causing a high client data heterogeneity and its
- impact on more advanced FL algorithms, namely FedProx and FedOpt. Both achieve a better performance compared to FedAvg
- with the same `alpha` setting but FedOpt shows a better convergence rate by utilizing SGD with momentum
- to update the global model on the server, and achieves a better performance with the same amount of training steps.
+ Finally, we compare an `alpha` setting of 0.1, which causes high client data heterogeneity, and its
+ impact on the more advanced FL algorithms, namely FedProx, FedOpt, and SCAFFOLD. FedOpt and SCAFFOLD achieve better performance than FedAvg and FedProx with the same `alpha` setting, and both show markedly better convergence rates. SCAFFOLD achieves this by adding a correction term when updating the client models, while FedOpt utilizes SGD with momentum
+ to update the global model on the server. Both reach a higher validation score within the same number of training steps as FedAvg/FedProx.

| Config | Alpha | Val score |
- |------------------| ----------- | ----------- |
- | cifar10_fedavg | 0.1 | 0.7726 |
- | cifar10_fedprox | 0.1 | 0.7512 |
- | cifar10_fedopt | 0.1 | 0.7986 |
+ |------------------| ----------- | ---------- |
+ | cifar10_fedavg | 0.1 | 0.7733 |
+ | cifar10_fedprox | 0.1 | 0.7615 |
+ | cifar10_fedopt | 0.1 | 0.8013 |
+ | cifar10_scaffold | 0.1 | 0.8222 |

- ![FedProx vs. FedOpt](./figs/fedopt_fedprox.png)
+ ![FedProx vs. FedOpt](./figs/fedopt_fedprox_scaffold.png)


## 5. Streaming TensorBoard metrics to the server

In a real-world scenario, the researcher won't have access to the TensorBoard events of the individual clients. In order to visualize the training performance in a central place, `AnalyticsSender`, `ConvertToFedEvent` on the client, and `TBAnalyticsReceiver` on the server can be used. For an example using FedAvg and metric streaming during training, run:
```
- ./run_poc.sh 8 cifar10_fedavg_stream_tb 9 1.0
+ ./run_poc.sh 8 cifar10_fedavg_stream_tb 10 1.0
```
Using this configuration, a `tb_events` folder will be created under the `run_*` folder of the server that includes all the TensorBoard event values of the different clients.
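From the learner's point of view, streaming boils down to writing scalars to an object that forwards them as federated events instead of to a local event file. A minimal sketch, assuming a writer with a TensorBoard-style `add_scalar` interface (the actual component wiring lives in the job's client and server configs):
```python
# Inside the client learner's training loop (illustrative):
# `writer` is assumed to expose TensorBoard's add_scalar() signature,
# but forwards the values to the server as federated events.
def log_training_metrics(writer, epoch: int, train_loss: float, val_acc: float) -> None:
    writer.add_scalar("train_loss", train_loss, epoch)
    writer.add_scalar("val_acc_global_model", val_acc, epoch)
```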
@@ -37,17 +37,14 @@
    },
    {
      "id": "model_locator",
-     "path": "pt.pt_model_locator.PTModelLocator",
-     "args": {}
-   },
-   {
-     "id": "formatter",
-     "path": "pt.pt_formatter.PTFormatter",
-     "args": {}
+     "name": "PTFileModelLocator",
+     "args": {
+       "pt_persistor_id": "persistor"
+     }
    },
    {
      "id": "json_generator",
-     "path": "pt.validation_json_generator.ValidationJsonGenerator",
+     "name": "ValidationJsonGenerator",
      "args": {}
    }
  ],
@@ -72,7 +69,6 @@
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
-       "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
@@ -37,17 +37,14 @@
    },
    {
      "id": "model_locator",
-     "path": "pt.pt_model_locator.PTModelLocator",
-     "args": {}
-   },
-   {
-     "id": "formatter",
-     "path": "pt.pt_formatter.PTFormatter",
-     "args": {}
+     "name": "PTFileModelLocator",
+     "args": {
+       "pt_persistor_id": "persistor"
+     }
    },
    {
      "id": "json_generator",
-     "path": "pt.validation_json_generator.ValidationJsonGenerator",
+     "name": "ValidationJsonGenerator",
      "args": {}
    }
  ],
@@ -72,7 +69,6 @@
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
-       "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
@@ -37,17 +37,14 @@
    },
    {
      "id": "model_locator",
-     "path": "pt.pt_model_locator.PTModelLocator",
-     "args": {}
-   },
-   {
-     "id": "formatter",
-     "path": "pt.pt_formatter.PTFormatter",
-     "args": {}
+     "name": "PTFileModelLocator",
+     "args": {
+       "pt_persistor_id": "persistor"
+     }
    },
    {
      "id": "json_generator",
-     "path": "pt.validation_json_generator.ValidationJsonGenerator",
+     "name": "ValidationJsonGenerator",
      "args": {}
    }
  ],
@@ -72,7 +69,6 @@
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
-       "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
@@ -37,17 +37,14 @@
    },
    {
      "id": "model_locator",
-     "path": "pt.pt_model_locator.PTModelLocator",
-     "args": {}
-   },
-   {
-     "id": "formatter",
-     "path": "pt.pt_formatter.PTFormatter",
-     "args": {}
+     "name": "PTFileModelLocator",
+     "args": {
+       "pt_persistor_id": "persistor"
+     }
    },
    {
      "id": "json_generator",
-     "path": "pt.validation_json_generator.ValidationJsonGenerator",
+     "name": "ValidationJsonGenerator",
      "args": {}
    },
    {
@@ -77,7 +74,6 @@
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
-       "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
@@ -56,17 +56,14 @@
    },
    {
      "id": "model_locator",
-     "path": "pt.pt_model_locator.PTModelLocator",
-     "args": {}
-   },
-   {
-     "id": "formatter",
-     "path": "pt.pt_formatter.PTFormatter",
-     "args": {}
+     "name": "PTFileModelLocator",
+     "args": {
+       "pt_persistor_id": "persistor"
+     }
    },
    {
      "id": "json_generator",
-     "path": "pt.validation_json_generator.ValidationJsonGenerator",
+     "name": "ValidationJsonGenerator",
      "args": {}
    }
  ],
@@ -91,7 +88,6 @@
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
-       "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
@@ -37,17 +37,14 @@
    },
    {
      "id": "model_locator",
-     "path": "pt.pt_model_locator.PTModelLocator",
-     "args": {}
-   },
-   {
-     "id": "formatter",
-     "path": "pt.pt_formatter.PTFormatter",
-     "args": {}
+     "name": "PTFileModelLocator",
+     "args": {
+       "pt_persistor_id": "persistor"
+     }
    },
    {
      "id": "json_generator",
-     "path": "pt.validation_json_generator.ValidationJsonGenerator",
+     "name": "ValidationJsonGenerator",
      "args": {}
    }
  ],
@@ -72,7 +69,6 @@
      "name": "CrossSiteModelEval",
      "args": {
        "model_locator_id": "model_locator",
-       "formatter_id": "formatter",
        "submit_model_timeout": 600,
        "validation_timeout": 6000,
        "cleanup_models": true
@@ -0,0 +1,37 @@
{
  "format_version": 2,

  "DATASET_ROOT": "/tmp/cifar10_data",

  "executors": [
    {
      "tasks": [
        "train", "submit_model", "validate"
      ],
      "executor": {
        "id": "Executor",
        "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor",
        "args": {
          "learner_id": "cifar10-learner"
        }
      }
    }
  ],

  "task_result_filters": [
  ],
  "task_data_filters": [
  ],

  "components": [
    {
      "id": "cifar10-learner",
      "path": "pt.learners.cifar10_scaffold_learner.CIFAR10ScaffoldLearner",
      "args": {
        "dataset_root": "{DATASET_ROOT}",
        "aggregation_epochs": 4,
        "lr": 1e-2
      }
    }
  ]
}
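The new client config above wires a `CIFAR10ScaffoldLearner` into the standard `LearnerExecutor` via `learner_id`. For readers new to SCAFFOLD, the state such a learner carries between rounds can be sketched as follows; the class and method names are illustrative, not the helper class added in this commit, and the control variates start from `torch.zeros_like` as noted in the commit message. It pairs with the per-step correction sketched in section 3.4.
```python
import torch

class ScaffoldState:
    """Illustrative client-side SCAFFOLD bookkeeping (not the NVFlare helper class).

    c_global / c_local are the server and client control variates, both initialized to zero."""

    def __init__(self, model: torch.nn.Module):
        self.c_global = {n: torch.zeros_like(p.data) for n, p in model.named_parameters()}
        self.c_local = {n: torch.zeros_like(p.data) for n, p in model.named_parameters()}

    def update_after_local_training(self, global_weights: dict, local_weights: dict,
                                    lr: float, num_steps: int) -> dict:
        """Option II of the SCAFFOLD paper:
        c_local_new = c_local - c_global + (x_global - y_local) / (num_steps * lr).
        Returns the control-variate delta sent to the server along with the model update."""
        c_delta = {}
        for name in self.c_local:
            c_new = (self.c_local[name] - self.c_global[name]
                     + (global_weights[name] - local_weights[name]) / (num_steps * lr))
            c_delta[name] = c_new - self.c_local[name]
            self.c_local[name] = c_new
        return c_delta
```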