-
Notifications
You must be signed in to change notification settings - Fork 4
/
rffl_clientmanager.py
38 lines (31 loc) · 1.67 KB
/
rffl_clientmanager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from core.utils import transform_list_to_grad
from distributed.inflator.inflator_client_manager import FedAVGInflatorClientManager
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../FedML/")))
from fedml_api.distributed.fedavg.FedAvgClientManager import FedAVGClientManager
from fedml_api.distributed.fedavg.message_define import MyMessage
from fedml_api.distributed.fedavg.utils import post_complete_message_to_sweep_process
class RFFLClientManager(FedAVGInflatorClientManager):
    """Client manager that trains locally and reports a scaled sample count.

    For every model message from the server, the payload is handed to the
    trainer via ``update_model_with_gradients``, one local training pass is
    run, and the result is sent back with the local sample count multiplied
    by ``self.water_powered_magnification``.

    NOTE(review): ``args``, ``trainer``, ``round_idx``, ``num_rounds``,
    ``water_powered_magnification``, ``send_model_to_server`` and ``finish``
    are not defined in this file -- presumably provided by the base class;
    confirm there.
    """

    def handle_message_receive_model_from_server(self, msg_params):
        """Apply the server payload, train one round, and finish on the last round.

        Args:
            msg_params: Incoming message; read with ``.get`` for
                ``MyMessage.MSG_ARG_KEY_MODEL_PARAMS`` and
                ``MyMessage.MSG_ARG_KEY_CLIENT_INDEX``.
        """
        logging.info("handle_message_receive_model_from_server.")
        received_update = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        assigned_client = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

        # Mobile runs pass the payload through transform_list_to_grad first
        # (presumably a list-encoded form -> gradient objects; verify in core.utils).
        if self.args.is_mobile == 1:
            received_update = transform_list_to_grad(received_update)

        self.trainer.update_model_with_gradients(received_update)
        self.trainer.update_dataset(int(assigned_client))

        # The round counter is advanced before training, so the training call
        # and the last-round check below both see the incremented value.
        self.round_idx += 1
        self.__train_with_inflation()

        if self.round_idx != self.num_rounds - 1:
            return

        # Last round: notify the sweep process, then shut this client down.
        post_complete_message_to_sweep_process(self.args)
        self.finish()

    def __train_with_inflation(self):
        """Run one local training pass and upload it with a scaled sample count."""
        banner = (
            "#######training with inflation########### round_id = %d" % self.round_idx
        )
        logging.info(banner)

        trained_weights, true_sample_count = self.trainer.train(self.round_idx)

        # The count sent to the server is the real count scaled by the
        # magnification factor and truncated to an int.
        reported_sample_count = int(
            true_sample_count * self.water_powered_magnification
        )

        # First argument 0 is presumably the server's receiver id -- TODO confirm
        # against send_model_to_server in the base class.
        self.send_model_to_server(0, trained_weights, reported_sample_count)