Skip to content

Commit 8e83701

Browse files
YuanTingHsiehgslama12chesterxgchen
authored
[2.6] Update flwr example (#3580)
Update flwr example ### Description Update flwr example and cherry-pick #3495 and #3550 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Georg Slamanig <[email protected]> Co-authored-by: Chester Chen <[email protected]>
1 parent e60800e commit 8e83701

File tree

10 files changed

+86
-49
lines changed

10 files changed

+86
-49
lines changed

docs/publications_and_talks.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Publications: 2022
3232
* **2022-10** `Auto-FedRL: Federated Hyperparameter Optimization for Multi-institutional Medical Image Segmentation <https://arxiv.org/abs/2203.06338>`__ (`ECCV 2022 <https://eccv2022.ecva.net/>`__)
3333
* **2022-10** `Joint Multi Organ and Tumor Segmentation from Partial Labels Using Federated Learning <https://link.springer.com/chapter/10.1007/978-3-031-18523-6_6>`__ (`DeCaF @ MICCAI 2022 <https://decaf-workshop.github.io/decaf-2022/>`__)
3434
* **2022-10** `Split-U-Net: Preventing Data Leakage in Split Learning for Collaborative Multi-modal Brain Tumor Segmentation <https://arxiv.org/abs/2208.10553>`__ (`DeCaF @ MICCAI 2022 <https://decaf-workshop.github.io/decaf-2022/>`__)
35-
* **2022-06** `Closing the Generalization Gap of Cross-silo Federated Medical Image Segmentation <https://openaccess.thecvf.com/content/CVPR2022/papers/Xu_Closing_the_Generalization_Gap_of_Cross-Silo_Federated_Medical_Image_Segmentation_CVPR_2022_paper.pdf>`__ (`CVPR 2022 <https://cvpr2022.thecvf.com/>`__)
35+
* **2022-06** `Closing the Generalization Gap of Cross-silo Federated Medical Image Segmentation <https://openaccess.thecvf.com/content/CVPR2022/papers/Xu_Closing_the_Generalization_Gap_of_Cross-Silo_Federated_Medical_Image_Segmentation_CVPR_2022_paper.pdf>`__ (CVPR 2022)
3636
* **2022-02** `Do Gradient Inversion Attacks Make Federated Learning Unsafe? <https://arxiv.org/abs/2202.06924>`__ (Preprint)
3737

3838
Publications: 2021

examples/hello-world/hello-flower/README.md

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# Flower App (PyTorch) in NVIDIA FLARE
22

3-
In this example, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
3+
In this example, we run 2 clients and 1 server using NVFlare's simulator.
44

55
## Preconditions
66

7-
To run Flower code in NVFlare, we created a job, including an app with the following custom folder content
7+
Following https://github.com/adap/flower/tree/main/examples/quickstart-pytorch we prepare the following flower app:
8+
89
```bash
9-
$ tree jobs/hello-flwr-pt/app/custom
10+
$ tree flwr-pt
1011

1112
├── flwr_pt
1213
│ ├── client.py # <-- contains `ClientApp`
@@ -15,38 +16,42 @@ $ tree jobs/hello-flwr-pt/app/custom
1516
│ └── task.py # <-- task-specific code (model, data)
1617
└── pyproject.toml # <-- Flower project file
1718
```
18-
Note, this code is adapted from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example.
19+
20+
To be run inside NVFlare, we need to add the following sections to "pyproject.toml":
21+
```
22+
[tool.flwr.app.config]
23+
num-server-rounds = 3
24+
25+
[tool.flwr.federations]
26+
default = "local-simulation"
27+
28+
[tool.flwr.federations.local-simulation]
29+
options.num-supernodes = 2
30+
address = "127.0.0.1:9093"
31+
insecure = true
32+
```
33+
34+
You can adjust the num-server-rounds.
35+
The number `options.num-supernodes` should match the number of NVFlare clients defined in [job.py](./job.py), e.g., `job.simulator_run(args.workdir, gpu="0", n_clients=2)`.
1936

2037
## 1. Install dependencies
2138
If you haven't already, we recommend creating a virtual environment.
2239
```bash
2340
python3 -m venv nvflare_flwr
2441
source nvflare_flwr/bin/activate
42+
pip install -r ./requirements.txt
2543
```
26-
We recommend installing an older version of NumPy as torch/torchvision doesn't support NumPy 2 at this time.
27-
```bash
28-
pip install numpy==1.26.4
29-
```
30-
## 2.1 Run a simulation
3144

32-
To run flwr-pt job with NVFlare, we first need to install its dependencies.
33-
```bash
34-
pip install ./flwr-pt/
35-
```
45+
## 2.1 Run flwr-pt with NVFlare simulation
3646

37-
Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
47+
We run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
3848
```bash
3949
python job.py --job_name "flwr-pt" --content_dir "./flwr-pt"
4050
```
4151

42-
## 2.2 Run a simulation with TensorBoard streaming
43-
44-
To run flwr-pt_tb_streaming job with NVFlare, we first need to install its dependencies.
45-
```bash
46-
pip install ./flwr-pt-tb/
47-
```
52+
## 2.2 Run flwr-pt with NVFlare simulation and NVFlare's TensorBoard streaming
4853

49-
Next, we run 2 Flower clients and Flower Server in parallel using NVFlare while streaming
54+
We run 2 Flower clients and Flower Server in parallel using NVFlare while streaming
5055
the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming.
5156

5257
```bash
@@ -59,16 +64,19 @@ tensorboard --logdir /tmp/nvflare/hello-flower
5964
```
6065
![tensorboard training curve](./train.png)
6166

62-
## Notes
63-
Make sure your `pyproject.toml` files in the Flower apps contain an "address" field. This needs to be present as the `--federation-config` option of the `flwr run` command tries to override the `“address”` field.
64-
Your `pyproject.toml` should include a section similar to this:
67+
68+
## 3. Run with real deployment
69+
70+
First, check real-world deployment guide: https://nvflare.readthedocs.io/en/main/real_world_fl/overview.html
71+
72+
Second, export the corresponding NVFlare job:
73+
```bash
74+
python job.py --job_name "flwr-pt" --content_dir "./flwr-pt" --export_job --export_dir "./jobs"
6575
```
66-
[tool.flwr.federations]
67-
default = "xxx"
6876

69-
[tool.flwr.federations.xxx]
70-
options.num-supernodes = 2
71-
address = "127.0.0.1:9093"
72-
insecure = false
77+
An NVFlare job will be generated at "./jobs" folder.
78+
79+
Then you can copy it inside the admin console's transfer folder and then run:
80+
```bash
81+
submit_job flwr-pt
7382
```
74-
The number `options.num-supernodes` should match the number of NVFlare clients defined in [job.py](./job.py), e.g., `job.simulator_run(args.workdir, gpu="0", n_clients=2)`.

examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from flwr.client import ClientApp, NumPyClient
1717
from flwr.common import Context
18-
from flwr.common.record import MetricsRecord, RecordSet
18+
from flwr.common.record import MetricRecord, RecordDict
1919

2020
from .task import DEVICE, Net, get_weights, load_data, set_weights, test, train
2121

@@ -36,16 +36,16 @@ def __init__(self, context: Context):
3636
self.writer = SummaryWriter()
3737
self.flwr_context = context
3838

39-
if "step" not in context.state.metrics_records:
39+
if "step" not in context.state.metric_records:
4040
self.set_step(0)
4141

4242
def set_step(self, step: int):
43-
record = RecordSet()
44-
record["step"] = MetricsRecord({"step": step})
43+
record = RecordDict()
44+
record["step"] = MetricRecord({"step": step})
4545
self.flwr_context.state = record
4646

4747
def get_step(self):
48-
return int(self.flwr_context.state.metrics_records["step"]["step"])
48+
return int(self.flwr_context.state.metric_records["step"]["step"])
4949

5050
def fit(self, parameters, config):
5151
step = self.get_step()

examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,16 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
5353
initial_parameters=parameters,
5454
)
5555

56-
# Define config
57-
config = ServerConfig(num_rounds=3)
58-
5956

6057
# Flower ServerApp
6158
def server_fn(context: Context):
62-
return ServerAppComponents(
63-
strategy=strategy,
64-
config=config,
65-
)
59+
# Read from config
60+
num_rounds = context.run_config["num-server-rounds"]
61+
62+
# Define config
63+
config = ServerConfig(num_rounds=num_rounds)
64+
return ServerAppComponents(strategy=strategy, config=config)
6665

6766

67+
# Create ServerApp
6868
app = ServerApp(server_fn=server_fn)

examples/hello-world/hello-flower/flwr-pt-tb/pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
requires = ["hatchling"]
33
build-backend = "hatchling.build"
44

5+
# Tested with:
6+
# flwr==1.20.0
7+
# nvflare==2.6.1
8+
# torch==2.7.1
9+
# torchvision==0.22.1
10+
# tensorboard==2.20.0
11+
512
[project]
613
name = "flwr_pt_tb"
714
version = "1.0.0"
@@ -34,4 +41,4 @@ default = "local-simulation"
3441
[tool.flwr.federations.local-simulation]
3542
options.num-supernodes = 2
3643
address = "127.0.0.1:9093"
37-
insecure = true
44+
insecure = true

examples/hello-world/hello-flower/flwr-pt/pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
requires = ["hatchling"]
33
build-backend = "hatchling.build"
44

5+
# Tested with:
6+
# flwr==1.20.0
7+
# nvflare==2.6.1
8+
# torch==2.7.1
9+
# torchvision==0.22.1
10+
# tensorboard==2.20.0
11+
512
[project]
613
name = "flwr_pt"
714
version = "1.0.0"
@@ -34,4 +41,4 @@ default = "local-simulation"
3441
[tool.flwr.federations.local-simulation]
3542
options.num-supernodes = 2
3643
address = "127.0.0.1:9093"
37-
insecure = true
44+
insecure = true

examples/hello-world/hello-flower/job.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def main():
2626
parser.add_argument("--content_dir", type=str, required=True)
2727
parser.add_argument("--stream_metrics", action="store_true")
2828
parser.add_argument("--use_client_api", action="store_true")
29+
parser.add_argument("--export_job", action="store_true")
2930
parser.add_argument("--export_dir", type=str, default="jobs")
3031
parser.add_argument("--workdir", type=str, default="/tmp/nvflare/hello-flower")
3132
args = parser.parse_args()
@@ -36,15 +37,20 @@ def main():
3637
# only external client api works with the current flower integration
3738
env = {CLIENT_API_TYPE_KEY: ClientAPIType.EX_PROCESS_API.value}
3839

40+
num_of_clients = 2
41+
3942
job = FlowerPyTorchJob(
4043
name=args.job_name,
4144
flower_content=args.content_dir,
4245
stream_metrics=args.stream_metrics,
46+
min_clients=num_of_clients,
4347
extra_env=env,
4448
)
4549

46-
job.export_job(args.export_dir)
47-
job.simulator_run(os.path.join(args.workdir, job.name), gpu="0", n_clients=2)
50+
if args.export_job:
51+
job.export_job(args.export_dir)
52+
else:
53+
job.simulator_run(os.path.join(args.workdir, job.name), gpu="0", n_clients=num_of_clients)
4854

4955

5056
if __name__ == "__main__":
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
flwr[simulation]>=1.16,<2.0
2+
nvflare>=2.6.0
3+
torch
4+
torchvision
5+
tensorboard

nvflare/app_opt/flower/flower_job.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def __init__(
9494
self.to_clients(obj=flower_content)
9595

9696
if not stream_metrics:
97+
conf = ExternalConfigurator(component_ids=[])
98+
self.to_clients(conf, "client_api_config_preparer")
9799
return
98100

99101
# add required components for metrics streaming

research/fed-sm/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
This directory contains the code for the personalized federated learning algorithm FedSM described in
44

55
### Closing the Generalization Gap of Cross-silo Federated Medical Image Segmentation ([arXiv:2203.10144](https://arxiv.org/abs/2203.10144))
6-
Accepted to [CVPR2022](https://cvpr2022.thecvf.com/).
6+
Accepted to CVPR2022.
77

88
###### Abstract:
99

@@ -88,3 +88,5 @@ BibTeX
8888
year={2022}
8989
}
9090
```
91+
92+

0 commit comments

Comments
 (0)