diff --git a/examples/hello-world/hello-flower/README.md b/examples/hello-world/hello-flower/README.md index 235d641b84..b896913c5e 100644 --- a/examples/hello-world/hello-flower/README.md +++ b/examples/hello-world/hello-flower/README.md @@ -33,19 +33,19 @@ pip install ./flwr-pt/ Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator. ```bash -python job.py +python job.py --job_name "flwr-pt" --content_dir "./flwr-pt" ``` ## 2.2 Run a simulation with TensorBoard streaming To run flwr-pt_tb_streaming job with NVFlare, we first need to install its dependencies. ```bash -pip install ./flwr-pt-metrics/ +pip install ./flwr-pt-tb/ ``` Next, we run 2 Flower clients and Flower Server in parallel using NVFlare while streaming the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming. ```bash -python job_with_metric.py +python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics --use_client_api ``` diff --git a/examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/__init__.py b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/__init__.py similarity index 100% rename from examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/__init__.py rename to examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/__init__.py diff --git a/examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/client.py b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py similarity index 100% rename from examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/client.py rename to examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py diff --git a/examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/server.py b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py similarity index 100% rename from examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/server.py rename to examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py diff --git a/examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/task.py b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/task.py similarity index 100% rename from examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/task.py rename to examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/task.py diff --git a/examples/hello-world/hello-flower/flwr-pt-metrics/pyproject.toml b/examples/hello-world/hello-flower/flwr-pt-tb/pyproject.toml similarity index 82% rename from examples/hello-world/hello-flower/flwr-pt-metrics/pyproject.toml rename to examples/hello-world/hello-flower/flwr-pt-tb/pyproject.toml index 36c45c8032..12b99c5c63 100644 --- a/examples/hello-world/hello-flower/flwr-pt-metrics/pyproject.toml +++ b/examples/hello-world/hello-flower/flwr-pt-tb/pyproject.toml @@ -3,7 +3,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "flwr_pt_tb_streaming" +name = "flwr_pt_tb" version = "1.0.0" description = "" license = "Apache-2.0" @@ -12,6 +12,7 @@ dependencies = [ "nvflare~=2.5.0rc", "torch==2.2.1", "torchvision==0.17.1", + "tensorboard" ] [tool.hatch.build.targets.wheel] @@ -21,8 +22,8 @@ packages = ["."] publisher = "nvidia" [tool.flwr.app.components] -serverapp = "flwr_pt_tb_streaming.server:app" -clientapp = "flwr_pt_tb_streaming.client:app" +serverapp = "flwr_pt_tb.server:app" +clientapp = "flwr_pt_tb.client:app" [tool.flwr.app.config] num-server-rounds = 3 diff --git a/examples/hello-world/hello-flower/job.py b/examples/hello-world/hello-flower/job.py index 7e9a9f0726..558c29ed4e 100644 --- a/examples/hello-world/hello-flower/job.py +++ b/examples/hello-world/hello-flower/job.py @@ -12,10 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. +from argparse import ArgumentParser + from nvflare.app_opt.flower.flower_job import FlowerJob +from nvflare.client.api import ClientAPIType +from nvflare.client.api_spec import CLIENT_API_TYPE_KEY -if __name__ == "__main__": - job = FlowerJob(name="flwr-pt", flower_content="./flwr-pt") - job.export_job("jobs") - job.simulator_run("/tmp/nvflare/flwr-pt", gpu="0", n_clients=2) +def main(): + parser = ArgumentParser() + parser.add_argument("--job_name", type=str, required=True) + parser.add_argument("--content_dir", type=str, required=True) + parser.add_argument("--stream_metrics", action="store_true") + parser.add_argument("--use_client_api", action="store_true") + parser.add_argument("--export_dir", type=str, default="jobs") + parser.add_argument("--workdir", type=str, default="/tmp/nvflare/hello-flower") + args = parser.parse_args() + + env = {} + if args.use_client_api: + env = {CLIENT_API_TYPE_KEY: ClientAPIType.EX_PROCESS_API.value} + + job = FlowerJob( + name=args.job_name, + flower_content=args.content_dir, + stream_metrics=args.stream_metrics, + extra_env=env, + ) + + job.export_job(args.export_dir) + job.simulator_run(args.workdir, gpu="0", n_clients=2) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/hello-flower/job_with_metric.py b/examples/hello-world/hello-flower/job_with_metric.py deleted file mode 100644 index 03969a804a..0000000000 --- a/examples/hello-world/hello-flower/job_with_metric.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from nvflare.app_opt.flower.flower_job import FlowerJob - -if __name__ == "__main__": - job = FlowerJob(name="flwr-pt-metrics", flower_content="./flwr-pt-metrics", stream_metrics=True) - - job.export_job("jobs") - job.simulator_run("/tmp/nvflare/flwr-pt-metrics", gpu="0", n_clients=2)