|
275 | 275 | "from nvflare.job_config.script_runner import ScriptRunner\n",
|
276 | 276 | "from nvflare.app_common.workflows.fedavg import FedAvg\n",
|
277 | 277 | "\n",
|
278 |
| - "job = FedJob(name=\"cifar10_fedavg\")" |
| 278 | + "job = FedJob(name=\"cifar10_pt_fedavg\")" |
279 | 279 | ]
|
280 | 280 | },
|
281 | 281 | {
|
|
378 | 378 | },
|
379 | 379 | {
|
380 | 380 | "cell_type": "markdown",
|
381 |
| - "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0", |
| 381 | + "id": "6059b304", |
382 | 382 | "metadata": {},
|
383 | 383 | "source": [
|
384 |
| - "#### 7. Add clients\n", |
385 |
| - "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n", |
386 |
| - "\n", |
387 |
| - "Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity.\n", |
388 |
| - "We can also specify, which GPU should be used to run this client, which is helpful for simulated environments." |
| 384 | + "#### 7. Add TB Event\n", |
| 385 | + "Add tensorboard logging to clients" |
389 | 386 | ]
|
390 | 387 | },
|
391 | 388 | {
|
392 | 389 | "cell_type": "code",
|
393 | 390 | "execution_count": null,
|
394 |
| - "id": "ad5d36fe-9ae5-43c3-80bc-2cdc66bf7a7e", |
| 391 | + "id": "51d8bcda", |
395 | 392 | "metadata": {},
|
396 | 393 | "outputs": [],
|
397 | 394 | "source": [
|
| 395 | + "from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent\n", |
| 396 | + "\n", |
398 | 397 | "for i in range(n_clients):\n",
|
399 |
| - " executor = ScriptRunner(\n", |
400 |
| - " script=\"src/cifar10_fl.py\", script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n", |
401 |
| - " )\n", |
402 |
| - " job.to(id=\"event_to_fed\", obj=executor, target=f\"site-{i+1}\")" |
| 398 | + " component = ConvertToFedEvent(events_to_convert=[\"analytix_log_stats\"], fed_event_prefix=\"fed.\")\n", |
| 399 | + " job.to(id=\"event_to_fed\", obj=component, target=f\"site-{i+1}\")" |
403 | 400 | ]
|
404 | 401 | },
|
405 | 402 | {
|
406 | 403 | "cell_type": "markdown",
|
407 |
| - "id": "a56abcd6-4e97-4a60-8894-2760f8815a03", |
| 404 | + "id": "7c95e3f6", |
408 | 405 | "metadata": {},
|
409 | 406 | "source": [
|
410 |
| - "#### 8. Add TB Event\n", |
411 |
| - "Add tensorboard logging to clients" |
| 407 | + "#### OPTIONAL: Define a FedAvgJob\n", |
| 408 | + "\n", |
| 409 | + "Alternatively, we can replace steps 2-7 and instead use the `FedAvgJob`.\n", |
| 410 | + "The `FedAvgJob` automatically configures the `FedAvg`` server controller, along the other components for model persistence, model selection, and TensorBoard streaming.\n" |
412 | 411 | ]
|
413 | 412 | },
|
414 | 413 | {
|
415 | 414 | "cell_type": "code",
|
416 | 415 | "execution_count": null,
|
417 |
| - "id": "a8a733e0-c0a9-4c36-b49d-16b20c2df7f6", |
| 416 | + "id": "c4dfc3e7", |
418 | 417 | "metadata": {},
|
419 | 418 | "outputs": [],
|
420 | 419 | "source": [
|
421 |
| - "from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent\n", |
| 420 | + "from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob\n", |
| 421 | + "\n", |
| 422 | + "n_clients = 2\n", |
422 | 423 | "\n",
|
| 424 | + "# Create FedAvg Job with initial model\n", |
| 425 | + "job = FedAvgJob(\n", |
| 426 | + " name=\"cifar10_pt_fedavg\",\n", |
| 427 | + " num_rounds=2,\n", |
| 428 | + " n_clients=n_clients,\n", |
| 429 | + " initial_model=Net(),\n", |
| 430 | + ")" |
| 431 | + ] |
| 432 | + }, |
| 433 | + { |
| 434 | + "cell_type": "markdown", |
| 435 | + "id": "548966c2-90bf-47ad-91d2-5c6c22c3c4f0", |
| 436 | + "metadata": {}, |
| 437 | + "source": [ |
| 438 | + "#### 8. Add client ScriptRunners\n", |
| 439 | + "Next, we can use the `ScriptRunner` and send it to each of the clients to run our training script.\n", |
| 440 | + "\n", |
| 441 | + "Note that our script could have additional input arguments, such as batch size or data path, but we don't use them here for simplicity." |
| 442 | + ] |
| 443 | + }, |
| 444 | + { |
| 445 | + "cell_type": "code", |
| 446 | + "execution_count": null, |
| 447 | + "id": "ad5d36fe-9ae5-43c3-80bc-2cdc66bf7a7e", |
| 448 | + "metadata": {}, |
| 449 | + "outputs": [], |
| 450 | + "source": [ |
423 | 451 | "for i in range(n_clients):\n",
|
424 |
| - " component = ConvertToFedEvent(events_to_convert=[\"analytix_log_stats\"], fed_event_prefix=\"fed.\")\n", |
425 |
| - " job.to(component, f\"site-{i+1}\")" |
| 452 | + " runner = ScriptRunner(\n", |
| 453 | + " script=\"src/cifar10_fl.py\", script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n", |
| 454 | + " )\n", |
| 455 | + " job.to(runner, f\"site-{i+1}\")" |
426 | 456 | ]
|
427 | 457 | },
|
428 | 458 | {
|
|
432 | 462 | "source": [
|
433 | 463 | "That's it!\n",
|
434 | 464 | "\n",
|
435 |
| - "#### 9 Optionally export the job\n", |
| 465 | + "#### 9. Optionally export the job\n", |
436 | 466 | "Now, we could export the job and submit it to a real NVFlare deployment using the [Admin client](https://nvflare.readthedocs.io/en/main/real_world_fl/operation.html) or [FLARE API](https://nvflare.readthedocs.io/en/main/real_world_fl/flare_api.html)."
|
437 | 467 | ]
|
438 | 468 | },
|
|
452 | 482 | "metadata": {},
|
453 | 483 | "source": [
|
454 | 484 | "#### 10. Run FL Simulation\n",
|
455 |
| - "Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. The results will be saved in the specified `workdir`." |
| 485 | + "Finally, we can run our FedJob in simulation using NVFlare's [simulator](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/fl_simulator.html) under the hood. We can also specify which GPU should be used to run this client, which is helpful for simulated environments. The results will be saved in the specified `workdir`." |
456 | 486 | ]
|
457 | 487 | },
|
458 | 488 | {
|
|
482 | 512 | "metadata": {},
|
483 | 513 | "outputs": [],
|
484 | 514 | "source": [
|
485 |
| - "! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_fedavg" |
| 515 | + "! nvflare simulator -w /tmp/nvflare/jobs/workdir -n 2 -t 2 -gpu 0 /tmp/nvflare/jobs/job_config/cifar10_pt_fedavg" |
486 | 516 | ]
|
487 | 517 | },
|
488 | 518 | {
|
|
0 commit comments