Commit 0bb9f6e

site, docs, example updates (#2894)
1 parent d4e914c commit 0bb9f6e

12 files changed: +272 -87 lines changed

docs/programming_guide/controllers/model_controller.rst

+48-4
@@ -7,13 +7,13 @@ ModelController API
 The FLARE :mod:`ModelController<nvflare.app_common.workflows.model_controller>` API provides an easy way for users to write and customize FLModel-based controller workflows.

 * Highly flexible with a simple API (run routine and basic communication and utility functions)
-* :ref:`fl_model`for the communication data structure, everything else is pure Python
+* :ref:`fl_model` for the communication data structure, everything else is pure Python
 * Option to support pre-existing components and FLARE-specific functionalities

 .. note::

     The ModelController API is a high-level API meant to simplify writing workflows.
-    If users prefer or need the full flexibility of the Controller with all the capabilites of FLARE functions, refer to the :ref:`controllers`.
+    If users prefer or need the full flexibility of the Controller with all the capabilities of FLARE functions, refer to the :ref:`controllers`.


 Core Concepts
@@ -70,7 +70,7 @@ Here is an example of the FedAvg workflow using the :class:`BaseFedAvg<nvflare.a
                 results, aggregate_fn=self.aggregate_fn
             )  # using default aggregate_fn with `WeightedAggregationHelper`. Can overwrite self.aggregate_fn with signature Callable[List[FLModel], FLModel]

-            # update global model with agggregation results
+            # update global model with aggregation results
             model = self.update_model(model, aggregate_results)

             # save model (by default uses persistor, can provide custom method)
@@ -119,7 +119,7 @@ The :ref:`fl_model` is standardized data structure object that is sent along wit

 The :ref:`fl_model` object can be any type of data depending on the specific task.
 For example, in the "train" and "validate" tasks we send the model parameters along with the task so the target clients can train and validate the model.
-However in many other tasks that do not involve sending the model (e.g. "submit_model"), the :ref:`fl_model` can contain any type of data (e.g. metadata, metrics etc.) or may be not be needed at all.
+However in many other tasks that do not involve sending the model (e.g. "submit_model"), the :ref:`fl_model` can contain any type of data (e.g. metadata, metrics etc.) or may not be needed at all.


 send_model_and_wait
@@ -142,6 +142,50 @@ A callback with the signature ``Callable[[FLModel], None]`` can be passed in, wh
 The task is standing until either ``min_responses`` have been received, or ``timeout`` time has passed.
 Since this call is asynchronous, the Controller :func:`get_num_standing_tasks<nvflare.apis.impl.controller.Controller.get_num_standing_tasks>` method can be used to get the number of standing tasks for synchronization purposes.

+For example, in the :github_nvflare_link:`CrossSiteEval <app_common/workflows/cross_site_eval.py>` workflow, the tasks are asynchronously sent with :func:`send_model<nvflare.app_common.workflows.model_controller.ModelController.send_model>` to get each client's model.
+Then through a callback, the clients' models are sent to the other clients for validation.
+Finally, the workflow waits for all standing tasks to complete with :func:`get_num_standing_tasks<nvflare.apis.impl.controller.Controller.get_num_standing_tasks>`.
+Below is an example of how these functions can be used. For more details, view the implementation of :github_nvflare_link:`CrossSiteEval <app_common/workflows/cross_site_eval.py>`.
+
+
+.. code-block:: python
+
+    class CrossSiteEval(ModelController):
+        ...
+        def run(self) -> None:
+            ...
+            # Create submit_model task and broadcast to all participating clients
+            self.send_model(
+                task_name=AppConstants.TASK_SUBMIT_MODEL,
+                data=data,
+                targets=self._participating_clients,
+                timeout=self._submit_model_timeout,
+                callback=self._receive_local_model_cb,
+            )
+            ...
+            # Wait for all standing tasks to complete, since we used non-blocking `send_model()`
+            while self.get_num_standing_tasks():
+                if self.abort_signal.triggered:
+                    self.info("Abort signal triggered. Finishing cross site validation.")
+                    return
+                self.debug("Checking standing tasks to see if cross site validation finished.")
+                time.sleep(self._task_check_period)
+
+            self.save_results()
+            self.info("Stop Cross-Site Evaluation.")
+
+        def _receive_local_model_cb(self, model: FLModel):
+            # Send this model to all clients to validate
+            model.meta[AppConstants.MODEL_OWNER] = model_name
+            self.send_model(
+                task_name=AppConstants.TASK_VALIDATION,
+                data=model,
+                targets=self._participating_clients,
+                timeout=self._validation_timeout,
+                callback=self._receive_val_result_cb,
+            )
+            ...
+
 
 Saving & Loading
 ================
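
The blocking counterpart, ``send_model_and_wait``, covers the common FedAvg-style round loop described earlier on this page. Below is a minimal illustrative sketch, assuming the ``BaseFedAvg`` helpers shown above (``load_model``, ``sample_clients``, ``aggregate``, ``update_model``, ``save_model``); it is not part of this commit's diff:

.. code-block:: python

    class SimpleFedAvg(BaseFedAvg):
        def run(self) -> None:
            # load the initial global model (by default from the persistor)
            model = self.load_model()

            for self.current_round in range(self.num_rounds):
                # select the participating clients for this round
                clients = self.sample_clients(self.num_clients)

                # blocking call: returns once min_responses or timeout is reached
                results = self.send_model_and_wait(data=model, targets=clients)

                # aggregate the returned FLModels, then fold them into the global model
                aggregate_results = self.aggregate(results)
                model = self.update_model(model, aggregate_results)

                # save model (by default uses the persistor)
                self.save_model(model)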

docs/release_notes/flare_250.rst

+2-2
@@ -9,7 +9,7 @@ scientists' experience working with FLARE. The new API covers client, server and

 Model Controller API
 --------------------
-The new Model Controller API greatly simplifies the experience of developing new federated learning workflows. Users can simply subclass
+The new :ref:`model_controller` greatly simplifies the experience of developing new federated learning workflows. Users can simply subclass
 the ModelController to develop new workflows. The new API doesn't require users to know the details of NVFlare constructs except for FLModel
 class, where it is simply a data structure that contains model weights, optimization parameters and metadata.

@@ -104,7 +104,7 @@ federated stats will be very helpful.

 FedAvg Early Stopping Example
 ------------------------------
-The `FedAvg Early Stopping example <https://github.com/NVIDIA/NVFlare/pull/2648>`_ tries to demonstrate that with the new server-side model
+The :github_nvflare_link:`FedAvg Early Stopping example <examples/hello-world/hello-fedavg>` tries to demonstrate that with the new server-side model
 controller API, it is very easy to change the control conditions and adjust workflows with a few lines of python code.

 Tensorflow Algorithms & Examples

examples/advanced/job_api/pt/src/cifar10_lightning_fl.py

+1-4
@@ -71,10 +71,7 @@ def predict_dataloader(self):
 def main():
     model = LitNet()
     cifar10_dm = CIFAR10DataModule()
-    if torch.cuda.is_available():
-        trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1 if torch.cuda.is_available() else None)
-    else:
-        trainer = Trainer(max_epochs=1, devices=None)
+    trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")

     # (2) patch the lightning trainer
     flare.patch(trainer)
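
The replacement line keeps an explicit CUDA check; Lightning can also make this choice itself. An alternative sketch (assuming the example's ``Trainer`` import from ``pytorch_lightning``), equivalent to the updated line above:

    # "auto" resolves to GPU when torch.cuda.is_available(), otherwise CPU,
    # matching the explicit conditional in the updated line above
    trainer = Trainer(max_epochs=1, accelerator="auto", devices=1)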

examples/getting_started/pt/nvflare_lightning_getting_started.ipynb

+40-10
@@ -333,7 +333,7 @@
 "from nvflare.job_config.script_runner import ScriptRunner\n",
 "from nvflare.app_common.workflows.fedavg import FedAvg\n",
 "\n",
-"job = FedJob(name=\"cifar10_fedavg_lightning\")"
+"job = FedJob(name=\"cifar10_lightning_fedavg\")"
 ]
 },
 {
@@ -412,16 +412,46 @@
 "That completes the components that need to be defined on the server."
 ]
 },
+{
+"cell_type": "markdown",
+"id": "32686782",
+"metadata": {},
+"source": [
+"#### OPTIONAL: Define a FedAvgJob\n",
+"\n",
+"Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n",
+"The `FedAvgJob` automatically configures the `FedAvg` server controller, along with the other components for model persistence and model selection."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "02fde3ae",
+"metadata": {},
+"outputs": [],
+"source": [
+"from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n",
+"\n",
+"n_clients = 2\n",
+"\n",
+"# Create FedAvg Job with initial model\n",
+"job = FedAvgJob(\n",
+"    name=\"cifar10_lightning_fedavg\",\n",
+"    num_rounds=2,\n",
+"    n_clients=n_clients,\n",
+"    initial_model=LitNet(),\n",
+")"
+]
+},
 {
 "cell_type": "markdown",
 "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0",
 "metadata": {},
 "source": [
-"#### 5. Add clients\n",
+"#### 6. Add client ScriptRunners\n",
 "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n",
 "\n",
-"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity.\n",
-"We can also specify, which GPU should be used to run this client, which is helpful for simulated environments."
+"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity."
 ]
 },
 {
@@ -432,10 +432,10 @@
 "outputs": [],
 "source": [
 "for i in range(n_clients):\n",
-"    executor = ScriptRunner(\n",
+"    runner = ScriptRunner(\n",
 "        script=\"src/cifar10_lightning_fl.py\", script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n",
 "    )\n",
-"    job.to(executor, f\"site-{i+1}\")"
+"    job.to(runner, f\"site-{i+1}\")"
 ]
 },
 {
@@ -445,7 +445,7 @@
 "source": [
 "That's it!\n",
 "\n",
-"#### 6. Optionally export the job\n",
+"#### 7. Optionally export the job\n",
 "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html)."
 ]
 },
@@ -464,8 +464,8 @@
 "id": "9ac3f0a8-06bb-4bea-89d3-4a5fc5b76c63",
 "metadata": {},
 "source": [
-"#### 7. Run FL Simulation\n",
-"Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. The results will be saved in the specified `workdir`."
+"#### 8. Run FL Simulation\n",
+"Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`."
 ]
 },
 {
@@ -495,7 +495,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_fedavg_lightning"
+"! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_lightning_fedavg"
 ]
 }
 ],
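
Taken together, the optional FedAvgJob path added to this notebook reduces to a short script. An illustrative sketch, assuming the notebook's LitNet LightningModule and src/cifar10_lightning_fl.py, plus FedJob's simulator_run helper:

    from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob
    from nvflare.job_config.script_runner import ScriptRunner

    n_clients = 2

    # FedAvgJob wires up the FedAvg controller plus the model
    # persistence and model selection components automatically
    job = FedAvgJob(
        name="cifar10_lightning_fedavg",
        num_rounds=2,
        n_clients=n_clients,
        initial_model=LitNet(),  # the LightningModule defined in the notebook
    )

    # one ScriptRunner per simulated site
    for i in range(n_clients):
        job.to(ScriptRunner(script="src/cifar10_lightning_fl.py"), f"site-{i+1}")

    # run in the FL simulator; results land in the given workspace
    job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")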

examples/getting_started/pt/nvflare_pt_getting_started.ipynb

+52-22
@@ -275,7 +275,7 @@
 "from nvflare.job_config.script_runner import ScriptRunner\n",
 "from nvflare.app_common.workflows.fedavg import FedAvg\n",
 "\n",
-"job = FedJob(name=\"cifar10_fedavg\")"
+"job = FedJob(name=\"cifar10_pt_fedavg\")"
 ]
 },
 {
@@ -378,51 +378,81 @@
 },
 {
 "cell_type": "markdown",
-"id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0",
+"id": "6059b304",
 "metadata": {},
 "source": [
-"#### 7. Add clients\n",
-"Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n",
-"\n",
-"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity.\n",
-"We can also specify, which GPU should be used to run this client, which is helpful for simulated environments."
+"#### 7. Add TB Event\n",
+"Add tensorboard logging to clients"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "ad5d36fe-9ae5-43c3-80bc-2cdc66bf7a7e",
+"id": "51d8bcda",
 "metadata": {},
 "outputs": [],
 "source": [
+"from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent\n",
+"\n",
 "for i in range(n_clients):\n",
-"    executor = ScriptRunner(\n",
-"        script=\"src/cifar10_fl.py\", script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n",
-"    )\n",
-"    job.to(id=\"event_to_fed\", obj=executor, target=f\"site-{i+1}\")"
+"    component = ConvertToFedEvent(events_to_convert=[\"analytix_log_stats\"], fed_event_prefix=\"fed.\")\n",
+"    job.to(id=\"event_to_fed\", obj=component, target=f\"site-{i+1}\")"
 ]
 },
 {
 "cell_type": "markdown",
-"id": "a56abcd6-4e97-4a60-8894-2760f8815a03",
+"id": "7c95e3f6",
 "metadata": {},
 "source": [
-"#### 8. Add TB Event\n",
-"Add tensorboard logging to clients"
+"#### OPTIONAL: Define a FedAvgJob\n",
+"\n",
+"Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n",
+"The `FedAvgJob` automatically configures the `FedAvg` server controller, along with the other components for model persistence, model selection, and TensorBoard streaming.\n"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "a8a733e0-c0a9-4c36-b49d-16b20c2df7f6",
+"id": "c4dfc3e7",
 "metadata": {},
 "outputs": [],
 "source": [
-"from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent\n",
+"from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n",
+"\n",
+"n_clients = 2\n",
 "\n",
+"# Create FedAvg Job with initial model\n",
+"job = FedAvgJob(\n",
+"    name=\"cifar10_pt_fedavg\",\n",
+"    num_rounds=2,\n",
+"    n_clients=n_clients,\n",
+"    initial_model=Net(),\n",
+")"
+]
+},
+{
+"cell_type": "markdown",
+"id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0",
+"metadata": {},
+"source": [
+"#### 8. Add client ScriptRunners\n",
+"Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n",
+"\n",
+"Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "ad5d36fe-9ae5-43c3-80bc-2cdc66bf7a7e",
+"metadata": {},
+"outputs": [],
+"source": [
 "for i in range(n_clients):\n",
-"    component = ConvertToFedEvent(events_to_convert=[\"analytix_log_stats\"], fed_event_prefix=\"fed.\")\n",
-"    job.to(component, f\"site-{i+1}\")"
+"    runner = ScriptRunner(\n",
+"        script=\"src/cifar10_fl.py\", script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n",
+"    )\n",
+"    job.to(runner, f\"site-{i+1}\")"
 ]
 },
 {
@@ -432,7 +432,7 @@
 "source": [
 "That's it!\n",
 "\n",
-"#### 9 Optionally export the job\n",
+"#### 9. Optionally export the job\n",
 "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html)."
 ]
 },
@@ -452,7 +452,7 @@
 "metadata": {},
 "source": [
 "#### 10. Run FL Simulation\n",
-"Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. The results will be saved in the specified `workdir`."
+"Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`."
 ]
 },
 {
@@ -482,7 +482,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_fedavg"
+"! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_pt_fedavg"
 ]
 },
 {
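
On the client side, the "analytix_log_stats" events that ConvertToFedEvent forwards to the server are produced by the training script's metric logging. A sketch of that counterpart, assuming NVFlare's Client API tracking writer (nvflare.client.tracking.SummaryWriter); the exact training loop is illustrative:

    import nvflare.client as flare
    from nvflare.client.tracking import SummaryWriter

    flare.init()
    writer = SummaryWriter()

    for step in range(100):
        loss = 1.0 / (step + 1)  # placeholder for a real training loss
        # emitted as an "analytix_log_stats" event, which ConvertToFedEvent
        # relays to the server with the "fed." prefix for TensorBoard streaming
        writer.add_scalar("train_loss", loss, step)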

examples/getting_started/pt/src/cifar10_lightning_fl.py

+1-4
@@ -71,10 +71,7 @@ def predict_dataloader(self):
 def main():
     model = LitNet()
     cifar10_dm = CIFAR10DataModule()
-    if torch.cuda.is_available():
-        trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1 if torch.cuda.is_available() else None)
-    else:
-        trainer = Trainer(max_epochs=1, devices=None)
+    trainer = Trainer(max_epochs=1, devices=1, accelerator="gpu" if torch.cuda.is_available() else "cpu")

     # (2) patch the lightning trainer
     flare.patch(trainer)
