-
Notifications
You must be signed in to change notification settings - Fork 181
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add flower metrics streaming example (#2764)
* Add flower metrics streaing example * Fix format * Use context and RecordSet * Undo stuff * Update to new style * Update hello-flwr-pt_tb_streaming * Remove debug msgs * Update readme * Use flower job * Add missing code * Make client api type an arg
- Loading branch information
1 parent
e956fea
commit 9737f53
Showing
17 changed files
with
424 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
85 changes: 85 additions & 0 deletions
85
examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from flwr.client import ClientApp, NumPyClient | ||
from flwr.common import Context | ||
from flwr.common.record import MetricsRecord, RecordSet | ||
|
||
from .task import DEVICE, Net, get_weights, load_data, set_weights, test, train | ||
|
||
# Load model and data (simple CNN, CIFAR-10) | ||
net = Net().to(DEVICE) | ||
trainloader, testloader = load_data() | ||
|
||
import nvflare.client as flare | ||
|
||
# initializes NVFlare interface | ||
from nvflare.client.tracking import SummaryWriter | ||
|
||
flare.init() | ||
|
||
|
||
# Define FlowerClient and client_fn | ||
class FlowerClient(NumPyClient): | ||
def __init__(self, context: Context): | ||
super().__init__() | ||
self.writer = SummaryWriter() | ||
self.set_context(context) | ||
if "step" not in context.state.metrics_records: | ||
self.set_step(0) | ||
|
||
def set_step(self, step: int): | ||
context = self.get_context() | ||
context.state = RecordSet(metrics_records={"step": MetricsRecord({"step": step})}) | ||
self.set_context(context) | ||
|
||
def get_step(self): | ||
context = self.get_context() | ||
return int(context.state.metrics_records["step"]["step"]) | ||
|
||
def fit(self, parameters, config): | ||
step = self.get_step() | ||
set_weights(net, parameters) | ||
results = train(net, trainloader, testloader, epochs=1, device=DEVICE) | ||
|
||
self.writer.add_scalar("train_loss", results["train_loss"], step) | ||
self.writer.add_scalar("train_accuracy", results["train_accuracy"], step) | ||
self.writer.add_scalar("val_loss", results["val_loss"], step) | ||
self.writer.add_scalar("val_accuracy", results["val_accuracy"], step) | ||
|
||
self.set_step(step + 1) | ||
|
||
return get_weights(net), len(trainloader.dataset), results | ||
|
||
def evaluate(self, parameters, config): | ||
set_weights(net, parameters) | ||
step = self.get_step() | ||
loss, accuracy = test(net, testloader) | ||
|
||
self.writer.add_scalar("test_loss", loss, step) | ||
self.writer.add_scalar("test_accuracy", accuracy, step) | ||
|
||
return loss, len(testloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
def client_fn(context: Context): | ||
"""Create and return an instance of Flower `Client`.""" | ||
return FlowerClient(context).to_client() | ||
|
||
|
||
# Flower ClientApp | ||
app = ClientApp( | ||
client_fn=client_fn, | ||
) |
65 changes: 65 additions & 0 deletions
65
examples/hello-world/hello-flower/flwr-pt-metrics/flwr_pt_tb_streaming/server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import List, Tuple | ||
|
||
from flwr.common import Metrics, ndarrays_to_parameters | ||
from flwr.server import ServerApp, ServerConfig | ||
from flwr.server.strategy import FedAvg | ||
|
||
from .task import Net, get_weights | ||
|
||
|
||
# Define metric aggregation function | ||
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: | ||
examples = [num_examples for num_examples, _ in metrics] | ||
|
||
# Multiply accuracy of each client by number of examples used | ||
train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] | ||
train_accuracies = [num_examples * m["train_accuracy"] for num_examples, m in metrics] | ||
val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] | ||
val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] | ||
|
||
# Aggregate and return custom metric (weighted average) | ||
return { | ||
"train_loss": sum(train_losses) / sum(examples), | ||
"train_accuracy": sum(train_accuracies) / sum(examples), | ||
"val_loss": sum(val_losses) / sum(examples), | ||
"val_accuracy": sum(val_accuracies) / sum(examples), | ||
} | ||
|
||
|
||
# Initialize model parameters | ||
ndarrays = get_weights(Net()) | ||
parameters = ndarrays_to_parameters(ndarrays) | ||
|
||
|
||
# Define strategy | ||
strategy = FedAvg( | ||
fraction_fit=1.0, # Select all available clients | ||
fraction_evaluate=0.0, # Disable evaluation | ||
min_available_clients=2, | ||
fit_metrics_aggregation_fn=weighted_average, | ||
initial_parameters=parameters, | ||
) | ||
|
||
|
||
# Define config | ||
config = ServerConfig(num_rounds=3) | ||
|
||
|
||
# Flower ServerApp | ||
app = ServerApp( | ||
config=config, | ||
strategy=strategy, | ||
) |
File renamed without changes.
34 changes: 34 additions & 0 deletions
34
examples/hello-world/hello-flower/flwr-pt-metrics/pyproject.toml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "flwr_pt_tb_streaming" | ||
version = "1.0.0" | ||
description = "" | ||
license = "Apache-2.0" | ||
dependencies = [ | ||
"flwr[simulation]>=1.11.0,<2.0", | ||
"nvflare~=2.5.0rc", | ||
"torch==2.2.1", | ||
"torchvision==0.17.1", | ||
] | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["."] | ||
|
||
[tool.flwr.app] | ||
publisher = "nvidia" | ||
|
||
[tool.flwr.app.components] | ||
serverapp = "flwr_pt_tb_streaming.server:app" | ||
clientapp = "flwr_pt_tb_streaming.client:app" | ||
|
||
[tool.flwr.app.config] | ||
num-server-rounds = 3 | ||
|
||
[tool.flwr.federations] | ||
default = "local-simulation" | ||
|
||
[tool.flwr.federations.local-simulation] | ||
options.num-supernodes = 2 |
14 changes: 14 additions & 0 deletions
14
examples/hello-world/hello-flower/flwr-pt/flwr_pt/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""flwr_pt.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
106 changes: 106 additions & 0 deletions
106
examples/hello-world/hello-flower/flwr-pt/flwr_pt/task.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections import OrderedDict | ||
from logging import INFO | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from flwr.common.logger import log | ||
from torch.utils.data import DataLoader | ||
from torchvision.datasets import CIFAR10 | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
|
||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
class Net(nn.Module): | ||
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" | ||
|
||
def __init__(self) -> None: | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
return self.fc3(x) | ||
|
||
|
||
def load_data(): | ||
"""Load CIFAR-10 (training and test set).""" | ||
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) | ||
trainset = CIFAR10("./data", train=True, download=True, transform=trf) | ||
testset = CIFAR10("./data", train=False, download=True, transform=trf) | ||
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset) | ||
|
||
|
||
def train(net, trainloader, valloader, epochs, device): | ||
"""Train the model on the training set.""" | ||
log(INFO, "Starting training...") | ||
net.to(device) # move model to GPU if available | ||
criterion = torch.nn.CrossEntropyLoss().to(device) | ||
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | ||
net.train() | ||
for _ in range(epochs): | ||
for images, labels in trainloader: | ||
images, labels = images.to(device), labels.to(device) | ||
optimizer.zero_grad() | ||
loss = criterion(net(images), labels) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
train_loss, train_acc = test(net, trainloader) | ||
val_loss, val_acc = test(net, valloader) | ||
|
||
results = { | ||
"train_loss": train_loss, | ||
"train_accuracy": train_acc, | ||
"val_loss": val_loss, | ||
"val_accuracy": val_acc, | ||
} | ||
return results | ||
|
||
|
||
def test(net, testloader): | ||
"""Validate the model on the test set.""" | ||
net.to(DEVICE) | ||
criterion = torch.nn.CrossEntropyLoss() | ||
correct, loss = 0, 0.0 | ||
with torch.no_grad(): | ||
for images, labels in testloader: | ||
outputs = net(images.to(DEVICE)) | ||
labels = labels.to(DEVICE) | ||
loss += criterion(outputs, labels).item() | ||
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() | ||
accuracy = correct / len(testloader.dataset) | ||
return loss, accuracy | ||
|
||
|
||
def get_weights(net): | ||
return [val.cpu().numpy() for _, val in net.state_dict().items()] | ||
|
||
|
||
def set_weights(net, parameters): | ||
params_dict = zip(net.state_dict().keys(), parameters) | ||
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) | ||
net.load_state_dict(state_dict, strict=True) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nvflare.app_opt.flower.flower_job import FlowerJob | ||
|
||
if __name__ == "__main__": | ||
job = FlowerJob(name="flwr-pt", flower_content="./flwr-pt") | ||
|
||
job.export_job("jobs") | ||
job.simulator_run("/tmp/nvflare/flwr-pt", gpu="0", n_clients=2) |
Oops, something went wrong.