Skip to content

Commit

Permalink
Update flower examples (#2871)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Aug 29, 2024
1 parent 7a843fb commit ba391aa
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 31 deletions.
6 changes: 3 additions & 3 deletions examples/hello-world/hello-flower/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@ pip install ./flwr-pt/

Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
```bash
python job.py
python job.py --job_name "flwr-pt" --content_dir "./flwr-pt"
```

## 2.2 Run a simulation with TensorBoard streaming

To run flwr-pt_tb_streaming job with NVFlare, we first need to install its dependencies.
```bash
pip install ./flwr-pt-metrics/
pip install ./flwr-pt-tb/
```

Next, we run 2 Flower clients and Flower Server in parallel using NVFlare while streaming
the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming.

```bash
python job_with_metric.py
python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics --use_client_api
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "flwr_pt_tb_streaming"
name = "flwr_pt_tb"
version = "1.0.0"
description = ""
license = "Apache-2.0"
Expand All @@ -12,6 +12,7 @@ dependencies = [
"nvflare~=2.5.0rc",
"torch==2.2.1",
"torchvision==0.17.1",
"tensorboard"
]

[tool.hatch.build.targets.wheel]
Expand All @@ -21,8 +22,8 @@ packages = ["."]
publisher = "nvidia"

[tool.flwr.app.components]
serverapp = "flwr_pt_tb_streaming.server:app"
clientapp = "flwr_pt_tb_streaming.client:app"
serverapp = "flwr_pt_tb.server:app"
clientapp = "flwr_pt_tb.client:app"

[tool.flwr.app.config]
num-server-rounds = 3
Expand Down
35 changes: 31 additions & 4 deletions examples/hello-world/hello-flower/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser

from nvflare.app_opt.flower.flower_job import FlowerJob
from nvflare.client.api import ClientAPIType
from nvflare.client.api_spec import CLIENT_API_TYPE_KEY

if __name__ == "__main__":
job = FlowerJob(name="flwr-pt", flower_content="./flwr-pt")

job.export_job("jobs")
job.simulator_run("/tmp/nvflare/flwr-pt", gpu="0", n_clients=2)
def main():
parser = ArgumentParser()
parser.add_argument("--job_name", type=str, required=True)
parser.add_argument("--content_dir", type=str, required=True)
parser.add_argument("--stream_metrics", action="store_true")
parser.add_argument("--use_client_api", action="store_true")
parser.add_argument("--export_dir", type=str, default="jobs")
parser.add_argument("--workdir", type=str, default="/tmp/nvflare/hello-flower")
args = parser.parse_args()

env = {}
if args.use_client_api:
env = {CLIENT_API_TYPE_KEY: ClientAPIType.EX_PROCESS_API.value}

job = FlowerJob(
name=args.job_name,
flower_content=args.content_dir,
stream_metrics=args.stream_metrics,
extra_env=env,
)

job.export_job(args.export_dir)
job.simulator_run(args.workdir, gpu="0", n_clients=2)


if __name__ == "__main__":
main()
21 changes: 0 additions & 21 deletions examples/hello-world/hello-flower/job_with_metric.py

This file was deleted.

0 comments on commit ba391aa

Please sign in to comment.