Skip to content

Commit

Permalink
Merge pull request #245 from EricDinging/fix-test
Browse files Browse the repository at this point in the history
Fix fed-yogi executor model download
  • Loading branch information
fanlai0990 committed Dec 17, 2023
2 parents e62ad70 + 7918a28 commit 0f90918
Show file tree
Hide file tree
Showing 10 changed files with 729 additions and 472 deletions.
2 changes: 1 addition & 1 deletion benchmark/configs/openimage/openimage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ job_conf:
- device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace
- device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace
- model: shufflenet_v2_x2_0 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2
- gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- gradient_policy: fed-yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- eval_interval: 30 # How many rounds to run a testing on the testing set
- rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds
- filter_less: 21 # Remove clients w/ less than 21 samples
Expand Down
2 changes: 1 addition & 1 deletion benchmark/configs/others/local_dp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ job_conf:
# - device_avail_file: /users/JIACHEN/FLPerf-Cluster/client_datamap/client_behave_trace
- sample_mode: random # Client selection: random, oort, random by default
- model: resnet18 # Models: shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2
- gradient_policy: fedavg # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- gradient_policy: fed-avg # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- eval_interval: 10 # How many rounds to run a testing on the testing set
- rounds: 500 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds
- filter_less: 20
Expand Down
2 changes: 1 addition & 1 deletion benchmark/configs/others/rl_conf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ job_conf:
- device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace
- device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace
- model: dqn # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2
- gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- gradient_policy: fed-yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- eval_interval: 50 # How many rounds to run a testing on the testing set
- rounds: 1000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds
- filter_less: 21 # Remove clients w/ less than 21 samples
Expand Down
2 changes: 1 addition & 1 deletion benchmark/configs/speech/google_speech.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ job_conf:
- device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace
- device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace
- model: resnet34 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2
- gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- gradient_policy: fed-yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default
- eval_interval: 30 # How many rounds to run a testing on the testing set
- rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds
- filter_less: 21 # Remove clients w/ less than 21 samples
Expand Down
469 changes: 285 additions & 184 deletions fedscale/cloud/aggregation/aggregator.py

Large diffs are not rendered by default.

76 changes: 50 additions & 26 deletions fedscale/cloud/aggregation/optimizers.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,107 @@
import numpy as np
import numpy as np
import torch


class TorchServerOptimizer(object):
"""This is a abstract server optimizer class
Args:
mode (string): mode of gradient aggregation policy
args (distionary): Variable arguments for fedscale runtime config. defaults to the setup in arg_parser.py
device (string): Runtime device type
sample_seed (int): Random seed
"""
def __init__(self, mode, args, device, sample_seed=233):

def __init__(self, mode, args, device, sample_seed=233):
self.mode = mode
self.args = args
self.device = device

if mode == 'fed-yogi':
if mode == "fed-yogi":
from fedscale.utils.optimizer.yogi import YoGi

self.gradient_controller = YoGi(
eta=args.yogi_eta, tau=args.yogi_tau, beta=args.yogi_beta, beta2=args.yogi_beta2)
eta=args.yogi_eta,
tau=args.yogi_tau,
beta=args.yogi_beta,
beta2=args.yogi_beta2,
)

def update_round_gradient(
self, last_model, current_model, target_model, client_training_results=None
):
"""update global model based on different policy
def update_round_gradient(self, last_model, current_model, target_model):
""" update global model based on different policy
Args:
last_model (list of tensor weight): A list of global model weight in last round.
current_model (list of tensor weight): A list of global model weight in this round.
target_model (PyTorch or TensorFlow nn module): Aggregated model.
client_training_results list of gradients from every clients, for q-fedavg
"""
if self.mode == 'fed-yogi':
if self.mode == "fed-yogi":
"""
"Adaptive Federated Optimizations",
"Adaptive Federated Optimizations",
Sashank J. Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konecný, Sanjiv Kumar, H. Brendan McMahan,
ICLR 2021.
"""
last_model = [x.to(device=self.device) for x in last_model]
current_model = [x.to(device=self.device) for x in current_model]

diff_weight = self.gradient_controller.update(
[pb-pa for pa, pb in zip(last_model, current_model)])
[pb - pa for pa, pb in zip(last_model, current_model)]
)

new_state_dict = {
name: torch.from_numpy(np.array(last_model[idx] + diff_weight[idx], dtype=np.float32))
name: torch.from_numpy(
np.array(last_model[idx] + diff_weight[idx], dtype=np.float32)
)
for idx, name in enumerate(target_model.state_dict().keys())
}

target_model.load_state_dict(new_state_dict)

elif self.mode == 'q-fedavg':
elif self.mode == "q-fedavg":
"""
"Fair Resource Allocation in Federated Learning", Tian Li, Maziar Sanjabi, Ahmad Beirami, Virginia Smith, ICLR 2020.
"""
learning_rate, qfedq = self.args.learning_rate, self.args.qfed_q
Deltas, hs = None, 0.
Deltas, hs = None, 0.0
last_model = [x.to(device=self.device) for x in last_model]

for result in self.client_training_results:
for result in client_training_results:
# plug in the weight updates into the gradient
grads = [(u - torch.from_numpy(v).to(device=self.device)) * 1.0 /
learning_rate for u, v in zip(last_model, result['update_weight'])]
loss = result['moving_loss']
update_weights = result["update_weight"]
if type(update_weights) is dict:
update_weights = [x for x in update_weights.values()]
weights = [
torch.from_numpy(np.asarray(x, dtype=np.float32)).to(
device=self.device
)
for x in update_weights
]
grads = [
(u - v) * 1.0 / learning_rate for u, v in zip(last_model, weights)
]
loss = result["moving_loss"]

if Deltas is None:
Deltas = [np.float_power(
loss+1e-10, qfedq) * grad for grad in grads]
Deltas = [
np.float_power(loss + 1e-10, qfedq) * grad for grad in grads
]
else:
for idx in range(len(Deltas)):
Deltas[idx] += np.float_power(loss +
1e-10, qfedq) * grads[idx]
Deltas[idx] += np.float_power(loss + 1e-10, qfedq) * grads[idx]

# estimation of the local Lipchitz constant
hs += (qfedq * np.float_power(loss+1e-10, (qfedq-1)) * torch.sum(torch.stack([torch.square(
grad).sum() for grad in grads])) + (1.0/learning_rate) * np.float_power(loss+1e-10, qfedq))
hs += qfedq * np.float_power(loss + 1e-10, (qfedq - 1)) * torch.sum(
torch.stack([torch.square(grad).sum() for grad in grads])
) + (1.0 / learning_rate) * np.float_power(loss + 1e-10, qfedq)

# update global model
for idx, param in enumerate(target_model.parameters()):
param.data = last_model[idx] - Deltas[idx]/(hs+1e-10)
param.data = last_model[idx] - Deltas[idx] / (hs + 1e-10)

else:
# The default optimizer, FedAvg, has been applied in aggregator.py on the fly
Expand Down
Loading

0 comments on commit 0f90918

Please sign in to comment.