Add flower metrics streaming example (#2764)
* Add flower metrics streaming example

* Fix format

* Use context and RecordSet

* Undo stuff

* Update to new style

* Update hello-flwr-pt_tb_streaming

* Remove debug msgs

* Update readme

* Use flower job

* Add missing code

* Make client api type an arg
YuanTingHsieh authored Aug 28, 2024
1 parent e956fea commit 9737f53
Showing 17 changed files with 424 additions and 48 deletions.
27 changes: 21 additions & 6 deletions examples/hello-world/hello-flower/README.md
@@ -17,20 +17,35 @@ $ tree jobs/hello-flwr-pt/app/custom
```
Note, this code is adapted from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example.

-## Install dependencies
+## 1. Install dependencies
If you haven't already, we recommend creating a virtual environment.
```bash
python3 -m venv nvflare_flwr
source nvflare_flwr/bin/activate
```
-To run a job with NVFlare, we first need to install its dependencies.
+
+## 2.1 Run a simulation
+
+To run the flwr-pt job with NVFlare, we first need to install its dependencies.
```bash
-pip install ./jobs/hello-flwr-pt/app/custom
+pip install ./flwr-pt/
```

-## Run a simulation
-
Next, we run 2 Flower clients and Flower Server in parallel using NVFlare's simulator.
```bash
-nvflare simulator jobs/hello-flwr-pt -n 2 -t 2 -w /tmp/nvflare/flwr
+python job.py
```
+
+## 2.2 Run a simulation with TensorBoard streaming
+
+To run the flwr-pt_tb_streaming job with NVFlare, we first need to install its dependencies.
+```bash
+pip install ./flwr-pt-metrics/
+```
+
+Next, we run 2 Flower clients and Flower Server in parallel using NVFlare, while streaming
+the TensorBoard metrics to the server at each iteration via NVFlare's metric streaming.
+
+```bash
+python job_with_metric.py
+```
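Once the streaming run finishes, the scalars collected on the server side can be viewed with TensorBoard. The log directory below is an assumption based on NVFlare's usual simulator workspace layout (it is not shown in this diff); adjust it to the workspace used by `job_with_metric.py`.

```bash
# Point TensorBoard at the server-side event files written by the metrics receiver
tensorboard --logdir /tmp/nvflare/flwr-pt-tb-streaming/server/simulate_job/tb_events
```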
@@ -0,0 +1,85 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from flwr.common.record import MetricsRecord, RecordSet

from .task import DEVICE, Net, get_weights, load_data, set_weights, test, train

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter

# initializes the NVFlare client API
flare.init()
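# SummaryWriter from nvflare.client.tracking mirrors the TensorBoard summary-writer API,
# but each add_scalar call is streamed to the NVFlare server as a metric event instead
# of being written to a local event file.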


# Define FlowerClient and client_fn
class FlowerClient(NumPyClient):
def __init__(self, context: Context):
super().__init__()
self.writer = SummaryWriter()
self.set_context(context)
if "step" not in context.state.metrics_records:
self.set_step(0)

def set_step(self, step: int):
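        # Store the step in the Flower Context's RecordSet rather than as a plain
        # attribute: Flower may build a fresh client object for every message, while
        # the Context state persists across fit/evaluate calls within a run.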
context = self.get_context()
context.state = RecordSet(metrics_records={"step": MetricsRecord({"step": step})})
self.set_context(context)

def get_step(self):
context = self.get_context()
return int(context.state.metrics_records["step"]["step"])

def fit(self, parameters, config):
step = self.get_step()
set_weights(net, parameters)
results = train(net, trainloader, testloader, epochs=1, device=DEVICE)

self.writer.add_scalar("train_loss", results["train_loss"], step)
self.writer.add_scalar("train_accuracy", results["train_accuracy"], step)
self.writer.add_scalar("val_loss", results["val_loss"], step)
self.writer.add_scalar("val_accuracy", results["val_accuracy"], step)

self.set_step(step + 1)

return get_weights(net), len(trainloader.dataset), results

def evaluate(self, parameters, config):
set_weights(net, parameters)
step = self.get_step()
loss, accuracy = test(net, testloader)

self.writer.add_scalar("test_loss", loss, step)
self.writer.add_scalar("test_accuracy", accuracy, step)

return loss, len(testloader.dataset), {"accuracy": accuracy}


def client_fn(context: Context):
"""Create and return an instance of Flower `Client`."""
return FlowerClient(context).to_client()


# Flower ClientApp
app = ClientApp(
client_fn=client_fn,
)
@@ -0,0 +1,65 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple

from flwr.common import Metrics, ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg

from .task import Net, get_weights


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
examples = [num_examples for num_examples, _ in metrics]

# Multiply accuracy of each client by number of examples used
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics]
train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics]
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics]
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics]

# Aggregate and return custom metric (weighted average)
return {
"train_loss": sum(train_losses) / sum(examples),
"train_accuracy": sum(train_accuracies) / sum(examples),
"val_loss": sum(val_losses) / sum(examples),
"val_accuracy": sum(val_accuracies) / sum(examples),
}


# Initialize model parameters
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)


# Define strategy
strategy = FedAvg(
fraction_fit=1.0, # Select all available clients
fraction_evaluate=0.0, # Disable evaluation
min_available_clients=2,
fit_metrics_aggregation_fn=weighted_average,
initial_parameters=parameters,
)


# Define config
config = ServerConfig(num_rounds=3)


# Flower ServerApp
app = ServerApp(
config=config,
strategy=strategy,
)
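As a quick illustration of the aggregation above (not part of the commit), `weighted_average` weights every metric by each client's example count; with two hypothetical clients reporting 50 and 150 examples:

```python
# Hypothetical inputs: (num_examples, metrics) per client
aggregated = weighted_average([
    (50, {"train_loss": 0.9, "train_accuracy": 0.55, "val_loss": 1.0, "val_accuracy": 0.50}),
    (150, {"train_loss": 0.7, "train_accuracy": 0.65, "val_loss": 0.8, "val_accuracy": 0.60}),
])
# e.g. train_loss = (50 * 0.9 + 150 * 0.7) / 200 = 0.75
```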
34 changes: 34 additions & 0 deletions examples/hello-world/hello-flower/flwr-pt-metrics/pyproject.toml
@@ -0,0 +1,34 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "flwr_pt_tb_streaming"
version = "1.0.0"
description = ""
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.11.0,<2.0",
"nvflare~=2.5.0rc",
"torch==2.2.1",
"torchvision==0.17.1",
]

[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.flwr.app]
publisher = "nvidia"

[tool.flwr.app.components]
serverapp = "flwr_pt_tb_streaming.server:app"
clientapp = "flwr_pt_tb_streaming.client:app"

[tool.flwr.app.config]
num-server-rounds = 3

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 2
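Because the `[tool.flwr.federations]` tables above define a `local-simulation` federation, the app can also be launched standalone with the Flower CLI, independent of NVFlare. A minimal sketch, assuming the `flwr` CLI from `flwr[simulation]` is installed and the command is run from the `flwr-pt-metrics` directory:

```bash
# Run the Flower app on its own local simulation (2 supernodes), without NVFlare
flwr run .
```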
14 changes: 14 additions & 0 deletions examples/hello-world/hello-flower/flwr-pt/flwr_pt/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""flwr_pt."""
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

106 changes: 106 additions & 0 deletions examples/hello-world/hello-flower/flwr-pt/flwr_pt/task.py
@@ -0,0 +1,106 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from logging import INFO

import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr.common.logger import log
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)


def train(net, trainloader, valloader, epochs, device):
"""Train the model on the training set."""
log(INFO, "Starting training...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
net.train()
for _ in range(epochs):
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()

train_loss, train_acc = test(net, trainloader)
val_loss, val_acc = test(net, valloader)

results = {
"train_loss": train_loss,
"train_accuracy": train_acc,
"val_loss": val_loss,
"val_accuracy": val_acc,
}
return results


def test(net, testloader):
"""Validate the model on the test set."""
net.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images.to(DEVICE))
labels = labels.to(DEVICE)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def get_weights(net):
return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
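A quick sanity check for the weight helpers above (illustrative, not part of the commit): the NumPy arrays returned by `get_weights` can be loaded into another model instance with `set_weights`:

```python
# Round-trip the parameters between two independent model instances
net_a, net_b = Net(), Net()
set_weights(net_b, get_weights(net_a))  # net_b now carries net_a's parameters
```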
21 changes: 21 additions & 0 deletions examples/hello-world/hello-flower/job.py
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.app_opt.flower.flower_job import FlowerJob

if __name__ == "__main__":
job = FlowerJob(name="flwr-pt", flower_content="./flwr-pt")

job.export_job("jobs")
job.simulator_run("/tmp/nvflare/flwr-pt", gpu="0", n_clients=2)
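Since the script also calls `job.export_job("jobs")`, the generated job folder can be run with the `nvflare simulator` CLI as well. A sketch, assuming the export produces `jobs/flwr-pt` (the job name used above):

```bash
# Equivalent CLI run of the exported job (2 clients, 2 threads)
nvflare simulator jobs/flwr-pt -n 2 -t 2 -w /tmp/nvflare/flwr-pt
```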