From 0e918c7ef1cd02ce8281b3d87ef7b5e2ceaf8610 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 10:31:40 -0500 Subject: [PATCH 01/17] Turn off optimizer in executor model update --- fedscale/cloud/execution/executor.py | 2 +- fedscale/cloud/internal/torch_model_adapter.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/fedscale/cloud/execution/executor.py b/fedscale/cloud/execution/executor.py index 3edad496..075631ce 100755 --- a/fedscale/cloud/execution/executor.py +++ b/fedscale/cloud/execution/executor.py @@ -397,7 +397,7 @@ def event_monitor(self): elif current_event == commons.UPDATE_MODEL: model_weights = self.deserialize_response(request.data) - self.UpdateModel(model_weights) + self.UpdateModel(model_weights, is_aggregator=False) elif current_event == commons.SHUT_DOWN: self.Stop() diff --git a/fedscale/cloud/internal/torch_model_adapter.py b/fedscale/cloud/internal/torch_model_adapter.py index 0d258ec9..e2769923 100644 --- a/fedscale/cloud/internal/torch_model_adapter.py +++ b/fedscale/cloud/internal/torch_model_adapter.py @@ -20,10 +20,11 @@ def __init__(self, model: torch.nn.Module, optimizer: TorchServerOptimizer = Non self.model = model self.optimizer = optimizer - def set_weights(self, weights: List[np.ndarray]): + def set_weights(self, weights: List[np.ndarray], is_aggregator=True): """ Set the model's weights to the numpy weights array. :param weights: numpy weights array + :param is_aggregator: boolean indicating whether the caller is the aggregator """ last_grad_weights = [param.data.clone() for param in self.model.state_dict().values()] new_state_dict = { @@ -31,7 +32,7 @@ def set_weights(self, weights: List[np.ndarray]): for i, name in enumerate(self.model.state_dict().keys()) } self.model.load_state_dict(new_state_dict) - if self.optimizer: + if self.optimizer and is_aggregator: weights_origin = copy.deepcopy(weights) weights = [torch.tensor(x) for x in weights_origin] self.optimizer.update_round_gradient(last_grad_weights, weights, self.model) From c06b14db19d854e9881baf985447a23eee79d16d Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 11:55:52 -0500 Subject: [PATCH 02/17] Turn off optimizer in executor model update --- fedscale/cloud/execution/executor.py | 205 ++++++++++++++++----------- 1 file changed, 125 insertions(+), 80 deletions(-) diff --git a/fedscale/cloud/execution/executor.py b/fedscale/cloud/execution/executor.py index 075631ce..502eeb70 100755 --- a/fedscale/cloud/execution/executor.py +++ b/fedscale/cloud/execution/executor.py @@ -33,7 +33,9 @@ def __init__(self, args): # initiate the executor log path, and executor ips logger.initiate_client_setting() - self.model_adapter = self.get_client_trainer(args).get_model_adapter(init_model()) + self.model_adapter = self.get_client_trainer(args).get_model_adapter( + init_model() + ) self.args = args self.num_executors = args.num_executors @@ -45,8 +47,7 @@ def __init__(self, args): self.training_sets = self.test_dataset = None # ======== channels ======== - self.aggregator_communicator = ClientConnections( - args.ps_ip, args.ps_port) + self.aggregator_communicator = ClientConnections(args.ps_ip, args.ps_port) # ======== runtime information ======== self.collate_fn = None @@ -56,28 +57,28 @@ def __init__(self, args): self.event_queue = collections.deque() if args.wandb_token != "": - os.environ['WANDB_API_KEY'] = args.wandb_token + os.environ["WANDB_API_KEY"] = args.wandb_token self.wandb = wandb if self.wandb.run is None: - self.wandb.init(project=f'fedscale-{args.job_name}', - name=f'executor{args.this_rank}-{args.time_stamp}', - group=f'{args.time_stamp}') + self.wandb.init( + project=f"fedscale-{args.job_name}", + name=f"executor{args.this_rank}-{args.time_stamp}", + group=f"{args.time_stamp}", + ) else: logging.error("Warning: wandb has already been initialized") - + else: self.wandb = None super(Executor, self).__init__() def setup_env(self): - """Set up experiments environment - """ + """Set up experiments environment""" logging.info(f"(EXECUTOR:{self.this_rank}) is setting up environ ...") self.setup_seed(seed=1) def setup_communication(self): - """Set up grpc connection - """ + """Set up grpc connection""" self.init_control_communication() self.init_data_communication() @@ -101,8 +102,7 @@ def init_control_communication(self): self.aggregator_communicator.connect_to_server() def init_data_communication(self): - """In charge of jumbo data traffics (e.g., fetch training result) - """ + """In charge of jumbo data traffics (e.g., fetch training result)""" pass def init_data(self): @@ -115,20 +115,27 @@ def init_data(self): train_dataset, test_dataset = init_dataset() if self.args.task == "rl": return train_dataset, test_dataset - if self.args.task == 'nlp': + if self.args.task == "nlp": self.collate_fn = collate - elif self.args.task == 'voice': + elif self.args.task == "voice": self.collate_fn = voice_collate_fn # load data partitionxr (entire_train_data) logging.info("Data partitioner starts ...") training_sets = DataPartitioner( - data=train_dataset, args=self.args, numOfClass=self.args.num_class) + data=train_dataset, args=self.args, numOfClass=self.args.num_class + ) training_sets.partition_data_helper( - num_clients=self.args.num_participants, data_map_file=self.args.data_map_file) + num_clients=self.args.num_participants, + data_map_file=self.args.data_map_file, + ) testing_sets = DataPartitioner( - data=test_dataset, args=self.args, numOfClass=self.args.num_class, isTest=True) + data=test_dataset, + args=self.args, + numOfClass=self.args.num_class, + isTest=True, + ) testing_sets.partition_data_helper(num_clients=self.num_executors) logging.info("Data partitioner completes ...") @@ -136,8 +143,7 @@ def init_data(self): return training_sets, testing_sets def run(self): - """Start running the executor by setting up execution and communication environment, and monitoring the grpc message. - """ + """Start running the executor by setting up execution and communication environment, and monitoring the grpc message.""" self.setup_env() self.training_sets, self.testing_sets = self.init_data() self.setup_communication() @@ -184,7 +190,7 @@ def UpdateModel(self, model_weights): """ self.round += 1 - self.model_adapter.set_weights(model_weights) + self.model_adapter.set_weights(model_weights, is_aggregator=False) def Train(self, config): """Load train config and data to start training on that client @@ -196,20 +202,25 @@ def Train(self, config): tuple (int, dictionary): The client id and train result """ - client_id, train_config = config['client_id'], config['task_config'] + client_id, train_config = config["client_id"], config["task_config"] - if 'model' not in config or not config['model']: + if "model" not in config or not config["model"]: raise "The 'model' object must be a non-null value in the training config." client_conf = self.override_conf(train_config) train_res = self.training_handler( - client_id=client_id, conf=client_conf, model=config['model']) + client_id=client_id, conf=client_conf, model=config["model"] + ) # Report execution completion meta information response = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION( job_api_pb2.CompleteRequest( - client_id=str(client_id), executor_id=self.executor_id, - event=commons.CLIENT_TRAIN, status=True, msg=None, - meta_result=None, data_result=None + client_id=str(client_id), + executor_id=self.executor_id, + event=commons.CLIENT_TRAIN, + status=True, + msg=None, + meta_result=None, + data_result=None, ) ) self.dispatch_worker_events(response) @@ -224,21 +235,24 @@ def Test(self, config): """ test_res = self.testing_handler() - test_res = {'executorId': self.this_rank, 'results': test_res} + test_res = {"executorId": self.this_rank, "results": test_res} # Report execution completion information response = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION( job_api_pb2.CompleteRequest( - client_id=self.executor_id, executor_id=self.executor_id, - event=commons.MODEL_TEST, status=True, msg=None, - meta_result=None, data_result=self.serialize_response(test_res) + client_id=self.executor_id, + executor_id=self.executor_id, + event=commons.MODEL_TEST, + status=True, + msg=None, + meta_result=None, + data_result=self.serialize_response(test_res), ) ) self.dispatch_worker_events(response) def Stop(self): - """Stop the current executor - """ + """Stop the current executor""" logging.info(f"Terminating the executor ...") self.aggregator_communicator.close_sever_connection() self.received_stop_request = True @@ -255,7 +269,7 @@ def report_executor_info_handler(self): return self.training_sets.getSize() def override_conf(self, config): - """ Override the variable arguments for different client + """Override the variable arguments for different client Args: config (dictionary): The client runtime config. @@ -280,7 +294,7 @@ def get_client_trainer(self, conf): if conf.engine == commons.TENSORFLOW: return TensorflowClient(conf) elif conf.engine == commons.PYTORCH: - if conf.task == 'rl': + if conf.task == "rl": return RLClient(conf) else: return TorchClient(conf) @@ -300,14 +314,21 @@ def training_handler(self, client_id, conf, model): self.model_adapter.set_weights(model) conf.client_id = client_id conf.tokenizer = tokenizer - client_data = self.training_sets if self.args.task == "rl" else \ - select_dataset(client_id, self.training_sets, - batch_size=conf.batch_size, args=self.args, - collate_fn=self.collate_fn - ) + client_data = ( + self.training_sets + if self.args.task == "rl" + else select_dataset( + client_id, + self.training_sets, + batch_size=conf.batch_size, + args=self.args, + collate_fn=self.collate_fn, + ) + ) client = self.get_client_trainer(self.args) train_res = client.train( - client_data=client_data, model=self.model_adapter.get_model(), conf=conf) + client_data=client_data, model=self.model_adapter.get_model(), conf=conf + ) return train_res @@ -321,25 +342,33 @@ def testing_handler(self): dictionary: The test result """ - test_config = self.override_conf({ - 'rank': self.this_rank, - 'memory_capacity': self.args.memory_capacity, - 'tokenizer': tokenizer - }) + test_config = self.override_conf( + { + "rank": self.this_rank, + "memory_capacity": self.args.memory_capacity, + "tokenizer": tokenizer, + } + ) client = self.get_client_trainer(test_config) - data_loader = select_dataset(self.this_rank, self.testing_sets, - batch_size=self.args.test_bsz, args=self.args, - isTest=True, collate_fn=self.collate_fn) + data_loader = select_dataset( + self.this_rank, + self.testing_sets, + batch_size=self.args.test_bsz, + args=self.args, + isTest=True, + collate_fn=self.collate_fn, + ) - test_results = client.test(data_loader, self.model_adapter.get_model(), test_config) + test_results = client.test( + data_loader, self.model_adapter.get_model(), test_config + ) self.log_test_result(test_results) gc.collect() return test_results def client_register(self): - """Register the executor information to the aggregator - """ + """Register the executor information to the aggregator""" start_time = time.time() while time.time() - start_time < 180: try: @@ -348,27 +377,29 @@ def client_register(self): client_id=self.executor_id, executor_id=self.executor_id, executor_info=self.serialize_response( - self.report_executor_info_handler()) + self.report_executor_info_handler() + ), ) ) self.dispatch_worker_events(response) break except Exception as e: - logging.warning(f"Failed to connect to aggregator {e}. Will retry in 5 sec.") + logging.warning( + f"Failed to connect to aggregator {e}. Will retry in 5 sec." + ) time.sleep(5) def client_ping(self): - """Ping the aggregator for new task - """ - response = self.aggregator_communicator.stub.CLIENT_PING(job_api_pb2.PingRequest( - client_id=self.executor_id, - executor_id=self.executor_id - )) + """Ping the aggregator for new task""" + response = self.aggregator_communicator.stub.CLIENT_PING( + job_api_pb2.PingRequest( + client_id=self.executor_id, executor_id=self.executor_id + ) + ) self.dispatch_worker_events(response) def event_monitor(self): - """Activate event handler once receiving new message - """ + """Activate event handler once receiving new message""" logging.info("Start monitoring events ...") self.client_register() @@ -380,24 +411,34 @@ def event_monitor(self): if current_event == commons.CLIENT_TRAIN: train_config = self.deserialize_response(request.meta) train_model = self.deserialize_response(request.data) - train_config['model'] = train_model - train_config['client_id'] = int(train_config['client_id']) + train_config["model"] = train_model + train_config["client_id"] = int(train_config["client_id"]) client_id, train_res = self.Train(train_config) # Upload model updates future_call = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION.future( - job_api_pb2.CompleteRequest(client_id=str(client_id), executor_id=self.executor_id, - event=commons.UPLOAD_MODEL, status=True, msg=None, - meta_result=None, data_result=self.serialize_response(train_res) - )) - future_call.add_done_callback(lambda _response: self.dispatch_worker_events(_response.result())) + job_api_pb2.CompleteRequest( + client_id=str(client_id), + executor_id=self.executor_id, + event=commons.UPLOAD_MODEL, + status=True, + msg=None, + meta_result=None, + data_result=self.serialize_response(train_res), + ) + ) + future_call.add_done_callback( + lambda _response: self.dispatch_worker_events( + _response.result() + ) + ) elif current_event == commons.MODEL_TEST: self.Test(self.deserialize_response(request.meta)) elif current_event == commons.UPDATE_MODEL: model_weights = self.deserialize_response(request.data) - self.UpdateModel(model_weights, is_aggregator=False) + self.UpdateModel(model_weights) elif current_event == commons.SHUT_DOWN: self.Stop() @@ -409,22 +450,26 @@ def event_monitor(self): try: self.client_ping() except Exception as e: - logging.info(f"Caught exception {e} from aggregator, terminating executor {self.this_rank} ...") + logging.info( + f"Caught exception {e} from aggregator, terminating executor {self.this_rank} ..." + ) self.Stop() - def log_test_result(self, test_res): - """Log test results to wandb server if enabled - """ + """Log test results to wandb server if enabled""" acc = round(test_res["top_1"] / test_res["test_len"], 4) acc_5 = round(test_res["top_5"] / test_res["test_len"], 4) test_loss = test_res["test_loss"] / test_res["test_len"] if self.wandb != None: - self.wandb.log({ - 'Test/round_to_top1_accuracy': acc, - 'Test/round_to_top5_accuracy': acc_5, - 'Test/round_to_loss': test_loss, - }, step=self.round) + self.wandb.log( + { + "Test/round_to_top1_accuracy": acc, + "Test/round_to_top5_accuracy": acc_5, + "Test/round_to_loss": test_loss, + }, + step=self.round, + ) + if __name__ == "__main__": executor = Executor(parser.args) From 157b17f23da2f674be589ef8c9dca5396bf9e972 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 16:31:47 -0500 Subject: [PATCH 03/17] Turn off optimizer in each train task at executor --- fedscale/cloud/execution/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedscale/cloud/execution/executor.py b/fedscale/cloud/execution/executor.py index 502eeb70..7a8b9d6c 100755 --- a/fedscale/cloud/execution/executor.py +++ b/fedscale/cloud/execution/executor.py @@ -311,7 +311,7 @@ def training_handler(self, client_id, conf, model): dictionary: The train result """ - self.model_adapter.set_weights(model) + self.model_adapter.set_weights(model, is_aggregator=False) conf.client_id = client_id conf.tokenizer = tokenizer client_data = ( From ba78059c3cf85fdf3de4caf523d0b5e2275b296d Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 17:26:36 -0500 Subject: [PATCH 04/17] Add qfedq argument --- fedscale/cloud/config_parser.py | 391 +++++++++++++++++++------------- 1 file changed, 236 insertions(+), 155 deletions(-) diff --git a/fedscale/cloud/config_parser.py b/fedscale/cloud/config_parser.py index 61196f41..ecf4938c 100644 --- a/fedscale/cloud/config_parser.py +++ b/fedscale/cloud/config_parser.py @@ -3,125 +3,137 @@ from fedscale.cloud import commons parser = argparse.ArgumentParser() -parser.add_argument('--job_name', type=str, default='demo_job') -parser.add_argument('--log_path', type=str, default='./', - help="default path is ../log") -parser.add_argument('--wandb_token', type=str, default="", - help="API key for wandb as login credentials") +parser.add_argument("--job_name", type=str, default="demo_job") +parser.add_argument("--log_path", type=str, default="./", help="default path is ../log") +parser.add_argument( + "--wandb_token", type=str, default="", help="API key for wandb as login credentials" +) # The basic configuration of the cluster -parser.add_argument('--ps_ip', type=str, default='127.0.0.1') -parser.add_argument('--ps_port', type=str, default='29500') -parser.add_argument('--this_rank', type=int, default=1) -parser.add_argument('--connection_timeout', type=int, default=60) -parser.add_argument('--experiment_mode', type=str, - default=commons.SIMULATION_MODE) -parser.add_argument('--engine', type=str, default=commons.PYTORCH, - help="Tensorflow or Pytorch for cloud aggregation") -parser.add_argument('--num_executors', type=int, default=1) -parser.add_argument('--executor_configs', type=str, - default="127.0.0.1:[1]") # seperated by ; +parser.add_argument("--ps_ip", type=str, default="127.0.0.1") +parser.add_argument("--ps_port", type=str, default="29500") +parser.add_argument("--this_rank", type=int, default=1) +parser.add_argument("--connection_timeout", type=int, default=60) +parser.add_argument("--experiment_mode", type=str, default=commons.SIMULATION_MODE) +parser.add_argument( + "--engine", + type=str, + default=commons.PYTORCH, + help="Tensorflow or Pytorch for cloud aggregation", +) +parser.add_argument("--num_executors", type=int, default=1) +parser.add_argument( + "--executor_configs", type=str, default="127.0.0.1:[1]" +) # seperated by ; # Note: In async mode, the num_participants param is treated as the async buffer size. In sync, this is the number # of clients that are selected each round. -parser.add_argument('--num_participants', type=int, default=4) -parser.add_argument('--data_map_file', type=str, default=None) -parser.add_argument('--use_cuda', type=str, default='True') -parser.add_argument('--cuda_device', type=str, default=None) -parser.add_argument('--time_stamp', type=str, default='logs') -parser.add_argument('--task', type=str, default='cv') -parser.add_argument('--device_avail_file', type=str, default=None) -parser.add_argument('--clock_factor', type=float, default=1.0, - help="Refactor the clock time given the profile") +parser.add_argument("--num_participants", type=int, default=4) +parser.add_argument("--data_map_file", type=str, default=None) +parser.add_argument("--use_cuda", type=str, default="True") +parser.add_argument("--cuda_device", type=str, default=None) +parser.add_argument("--time_stamp", type=str, default="logs") +parser.add_argument("--task", type=str, default="cv") +parser.add_argument("--device_avail_file", type=str, default=None) +parser.add_argument( + "--clock_factor", + type=float, + default=1.0, + help="Refactor the clock time given the profile", +) # The configuration of model and dataset -parser.add_argument('--model_zoo', type=str, default='torchcv', - help="model zoo to load the models from", choices=["torchcv", "fedscale-torch-zoo", - "fedscale-tensorflow-zoo"]) -parser.add_argument('--data_dir', type=str, default='~/cifar10/') -parser.add_argument('--device_conf_file', type=str, default='/tmp/client.cfg') -parser.add_argument('--model', type=str, default='shufflenet_v2_x2_0') -parser.add_argument('--data_set', type=str, default='cifar10') -parser.add_argument('--sample_mode', type=str, default='random') -parser.add_argument('--filter_less', type=int, default=32) -parser.add_argument('--filter_more', type=int, default=1e15) -parser.add_argument('--train_uniform', type=bool, default=False) -parser.add_argument('--conf_path', type=str, default='~/dataset/') -parser.add_argument('--overcommitment', type=float, default=1.3) -parser.add_argument('--model_size', type=float, default=65536) -parser.add_argument('--round_threshold', type=float, default=30) -parser.add_argument('--round_penalty', type=float, default=2.0) -parser.add_argument('--clip_bound', type=float, default=0.9) -parser.add_argument('--blacklist_rounds', type=int, default=-1) -parser.add_argument('--blacklist_max_len', type=float, default=0.3) -parser.add_argument('--embedding_file', type=str, - default='glove.840B.300d.txt') -parser.add_argument('--input_shape', type=int, nargs='+', default=[1, 3, 28, 28]) -parser.add_argument('--save_checkpoint', type=bool, default=False) +parser.add_argument( + "--model_zoo", + type=str, + default="torchcv", + help="model zoo to load the models from", + choices=["torchcv", "fedscale-torch-zoo", "fedscale-tensorflow-zoo"], +) +parser.add_argument("--data_dir", type=str, default="~/cifar10/") +parser.add_argument("--device_conf_file", type=str, default="/tmp/client.cfg") +parser.add_argument("--model", type=str, default="shufflenet_v2_x2_0") +parser.add_argument("--data_set", type=str, default="cifar10") +parser.add_argument("--sample_mode", type=str, default="random") +parser.add_argument("--filter_less", type=int, default=32) +parser.add_argument("--filter_more", type=int, default=1e15) +parser.add_argument("--train_uniform", type=bool, default=False) +parser.add_argument("--conf_path", type=str, default="~/dataset/") +parser.add_argument("--overcommitment", type=float, default=1.3) +parser.add_argument("--model_size", type=float, default=65536) +parser.add_argument("--round_threshold", type=float, default=30) +parser.add_argument("--round_penalty", type=float, default=2.0) +parser.add_argument("--clip_bound", type=float, default=0.9) +parser.add_argument("--blacklist_rounds", type=int, default=-1) +parser.add_argument("--blacklist_max_len", type=float, default=0.3) +parser.add_argument("--embedding_file", type=str, default="glove.840B.300d.txt") +parser.add_argument("--input_shape", type=int, nargs="+", default=[1, 3, 28, 28]) +parser.add_argument("--save_checkpoint", type=bool, default=False) # The configuration of different hyper-parameters for training -parser.add_argument('--rounds', type=int, default=50) -parser.add_argument('--local_steps', type=int, default=20) -parser.add_argument('--batch_size', type=int, default=30) -parser.add_argument('--test_bsz', type=int, default=128) -parser.add_argument('--backend', type=str, default="gloo") -parser.add_argument('--learning_rate', type=float, default=5e-2) -parser.add_argument('--min_learning_rate', type=float, default=5e-5) -parser.add_argument('--input_dim', type=int, default=0) -parser.add_argument('--output_dim', type=int, default=0) -parser.add_argument('--dump_epoch', type=int, default=1e10) -parser.add_argument('--decay_factor', type=float, default=0.98) -parser.add_argument('--decay_round', type=float, default=10) -parser.add_argument('--num_loaders', type=int, default=2) -parser.add_argument('--eval_interval', type=int, default=5) -parser.add_argument('--sample_seed', type=int, default=233) # 123 #233 -parser.add_argument('--test_ratio', type=float, default=1.0) -parser.add_argument('--loss_decay', type=float, default=0.2) -parser.add_argument('--exploration_min', type=float, default=0.3) -parser.add_argument('--cut_off_util', type=float, - default=0.05) # 95 percentile - -parser.add_argument('--gradient_policy', type=str, default=None) +parser.add_argument("--rounds", type=int, default=50) +parser.add_argument("--local_steps", type=int, default=20) +parser.add_argument("--batch_size", type=int, default=30) +parser.add_argument("--test_bsz", type=int, default=128) +parser.add_argument("--backend", type=str, default="gloo") +parser.add_argument("--learning_rate", type=float, default=5e-2) +parser.add_argument("--min_learning_rate", type=float, default=5e-5) +parser.add_argument("--input_dim", type=int, default=0) +parser.add_argument("--output_dim", type=int, default=0) +parser.add_argument("--dump_epoch", type=int, default=1e10) +parser.add_argument("--decay_factor", type=float, default=0.98) +parser.add_argument("--decay_round", type=float, default=10) +parser.add_argument("--num_loaders", type=int, default=2) +parser.add_argument("--eval_interval", type=int, default=5) +parser.add_argument("--sample_seed", type=int, default=233) # 123 #233 +parser.add_argument("--test_ratio", type=float, default=1.0) +parser.add_argument("--loss_decay", type=float, default=0.2) +parser.add_argument("--exploration_min", type=float, default=0.3) +parser.add_argument("--cut_off_util", type=float, default=0.05) # 95 percentile + +parser.add_argument("--gradient_policy", type=str, default=None) # for yogi -parser.add_argument('--yogi_eta', type=float, default=3e-3) -parser.add_argument('--yogi_tau', type=float, default=1e-8) -parser.add_argument('--yogi_beta', type=float, default=0.9) -parser.add_argument('--yogi_beta2', type=float, default=0.99) +parser.add_argument("--yogi_eta", type=float, default=3e-3) +parser.add_argument("--yogi_tau", type=float, default=1e-8) +parser.add_argument("--yogi_beta", type=float, default=0.9) +parser.add_argument("--yogi_beta2", type=float, default=0.99) + +# for q-fedavg +parser.add_argument("--qfed_q", type=float, default=1.0) # for prox -parser.add_argument('--proxy_mu', type=float, default=0.1) +parser.add_argument("--proxy_mu", type=float, default=0.1) # for detection -parser.add_argument('--cfg_file', type=str, - default='./utils/rcnn/cfgs/res101.yml') -parser.add_argument('--test_output_dir', type=str, default='./logs/server') -parser.add_argument('--train_size_file', type=str, default='') -parser.add_argument('--test_size_file', type=str, default='') -parser.add_argument('--data_cache', type=str, default='') -parser.add_argument('--backbone', type=str, default='./resnet50.pth') +parser.add_argument("--cfg_file", type=str, default="./utils/rcnn/cfgs/res101.yml") +parser.add_argument("--test_output_dir", type=str, default="./logs/server") +parser.add_argument("--train_size_file", type=str, default="") +parser.add_argument("--test_size_file", type=str, default="") +parser.add_argument("--data_cache", type=str, default="") +parser.add_argument("--backbone", type=str, default="./resnet50.pth") # for malicious -parser.add_argument('--malicious_factor', type=int, default=1e15) +parser.add_argument("--malicious_factor", type=int, default=1e15) # for asynchronous FL -parser.add_argument('--max_concurrency', type=int, default=10) -parser.add_argument('--max_staleness', type=int, default=5) +parser.add_argument("--max_concurrency", type=int, default=10) +parser.add_argument("--max_staleness", type=int, default=5) # for differential privacy -parser.add_argument('--noise_factor', type=float, default=0.1) -parser.add_argument('--clip_threshold', type=float, default=3.0) -parser.add_argument('--target_delta', type=float, default=0.0001) +parser.add_argument("--noise_factor", type=float, default=0.1) +parser.add_argument("--clip_threshold", type=float, default=3.0) +parser.add_argument("--target_delta", type=float, default=0.0001) # for Oort -parser.add_argument('--pacer_delta', type=float, default=5) -parser.add_argument('--pacer_step', type=int, default=20) -parser.add_argument('--exploration_alpha', type=float, default=0.3) -parser.add_argument('--exploration_factor', type=float, default=0.9) -parser.add_argument('--exploration_decay', type=float, default=0.98) -parser.add_argument('--sample_window', type=float, default=5.0) +parser.add_argument("--pacer_delta", type=float, default=5) +parser.add_argument("--pacer_step", type=int, default=20) +parser.add_argument("--exploration_alpha", type=float, default=0.3) +parser.add_argument("--exploration_factor", type=float, default=0.9) +parser.add_argument("--exploration_decay", type=float, default=0.98) +parser.add_argument("--sample_window", type=float, default=5.0) # for albert parser.add_argument( @@ -129,17 +141,26 @@ action="store_true", help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.", ) -parser.add_argument('--clf_block_size', type=int, default=32) +parser.add_argument("--clf_block_size", type=int, default=32) parser.add_argument( - "--mlm", type=bool, default=False, help="Train with masked-language modeling loss instead of language modeling." + "--mlm", + type=bool, + default=False, + help="Train with masked-language modeling loss instead of language modeling.", ) parser.add_argument( - "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" + "--mlm_probability", + type=float, + default=0.15, + help="Ratio of tokens to mask for masked language modeling loss", ) parser.add_argument( - "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" + "--overwrite_cache", + type=bool, + default=False, + help="Overwrite the cached training and evaluation sets", ) parser.add_argument( "--block_size", @@ -151,83 +172,143 @@ ) -parser.add_argument("--weight_decay", default=0, type=float, - help="Weight decay if we apply some.") -parser.add_argument("--adam_epsilon", default=1e-8, - type=float, help="Epsilon for Adam optimizer.") +parser.add_argument( + "--weight_decay", default=0, type=float, help="Weight decay if we apply some." +) +parser.add_argument( + "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer." +) # for tag prediction -parser.add_argument("--vocab_token_size", type=int, - default=10000, help="For vocab token size") -parser.add_argument("--vocab_tag_size", type=int, - default=500, help="For vocab tag size") +parser.add_argument( + "--vocab_token_size", type=int, default=10000, help="For vocab token size" +) +parser.add_argument( + "--vocab_tag_size", type=int, default=500, help="For vocab tag size" +) # for rl example parser.add_argument("--epsilon", type=float, default=0.9, help="greedy policy") parser.add_argument("--gamma", type=float, default=0.9, help="reward discount") -parser.add_argument("--memory_capacity", type=int, - default=2000, help="memory capacity") -parser.add_argument("--target_replace_iter", type=int, - default=15, help="update frequency") +parser.add_argument("--memory_capacity", type=int, default=2000, help="memory capacity") +parser.add_argument( + "--target_replace_iter", type=int, default=15, help="update frequency" +) parser.add_argument("--n_actions", type=int, default=2, help="action number") parser.add_argument("--n_states", type=int, default=4, help="state number") -parser.add_argument("--num_classes", type=int, default=35, - help="For number of classes of the dataset") +parser.add_argument( + "--num_classes", type=int, default=35, help="For number of classes of the dataset" +) # for voice -parser.add_argument('--train-manifest', metavar='DIR', - help='path to train manifest csv', default='data/train_manifest.csv') -parser.add_argument('--test-manifest', metavar='DIR', - help='path to test manifest csv', default='data/test_manifest.csv') -parser.add_argument('--sample-rate', default=16000, - type=int, help='Sample rate') -parser.add_argument('--labels-path', default='labels.json', - help='Contains all characters for transcription') -parser.add_argument('--window-size', default=.02, type=float, - help='Window size for spectrogram in seconds') -parser.add_argument('--window-stride', default=.01, type=float, - help='Window stride for spectrogram in seconds') -parser.add_argument('--window', default='hamming', - help='Window type for spectrogram generation') -parser.add_argument('--hidden-size', default=256, - type=int, help='Hidden size of RNNs') -parser.add_argument('--hidden-layers', default=7, - type=int, help='Number of RNN layers') -parser.add_argument('--rnn-type', default='lstm', - help='Type of the RNN. rnn|gru|lstm are supported') -parser.add_argument('--finetune', dest='finetune', action='store_true', - help='Finetune the model from checkpoint "continue_from"') -parser.add_argument('--speed-volume-perturb', dest='speed_volume_perturb', action='store_true', - help='Use random tempo and gain perturbations.') -parser.add_argument('--spec-augment', dest='spec_augment', action='store_true', - help='Use simple spectral augmentation on mel spectograms.') -parser.add_argument('--noise-dir', default=None, - help='Directory to inject noise into audio. If default, noise Inject not added') -parser.add_argument('--noise-prob', default=0.4, - help='Probability of noise being added per sample') -parser.add_argument('--noise-min', default=0.0, - help='Minimum noise level to sample from. (1.0 means all noise, not original signal)', type=float) -parser.add_argument('--noise-max', default=0.5, - help='Maximum noise levels to sample from. Maximum 1.0', type=float) -parser.add_argument('--no-bidirectional', dest='bidirectional', action='store_false', default=True, - help='Turn off bi-directional RNNs, introduces lookahead convolution') +parser.add_argument( + "--train-manifest", + metavar="DIR", + help="path to train manifest csv", + default="data/train_manifest.csv", +) +parser.add_argument( + "--test-manifest", + metavar="DIR", + help="path to test manifest csv", + default="data/test_manifest.csv", +) +parser.add_argument("--sample-rate", default=16000, type=int, help="Sample rate") +parser.add_argument( + "--labels-path", + default="labels.json", + help="Contains all characters for transcription", +) +parser.add_argument( + "--window-size", + default=0.02, + type=float, + help="Window size for spectrogram in seconds", +) +parser.add_argument( + "--window-stride", + default=0.01, + type=float, + help="Window stride for spectrogram in seconds", +) +parser.add_argument( + "--window", default="hamming", help="Window type for spectrogram generation" +) +parser.add_argument("--hidden-size", default=256, type=int, help="Hidden size of RNNs") +parser.add_argument("--hidden-layers", default=7, type=int, help="Number of RNN layers") +parser.add_argument( + "--rnn-type", default="lstm", help="Type of the RNN. rnn|gru|lstm are supported" +) +parser.add_argument( + "--finetune", + dest="finetune", + action="store_true", + help='Finetune the model from checkpoint "continue_from"', +) +parser.add_argument( + "--speed-volume-perturb", + dest="speed_volume_perturb", + action="store_true", + help="Use random tempo and gain perturbations.", +) +parser.add_argument( + "--spec-augment", + dest="spec_augment", + action="store_true", + help="Use simple spectral augmentation on mel spectograms.", +) +parser.add_argument( + "--noise-dir", + default=None, + help="Directory to inject noise into audio. If default, noise Inject not added", +) +parser.add_argument( + "--noise-prob", default=0.4, help="Probability of noise being added per sample" +) +parser.add_argument( + "--noise-min", + default=0.0, + help="Minimum noise level to sample from. (1.0 means all noise, not original signal)", + type=float, +) +parser.add_argument( + "--noise-max", + default=0.5, + help="Maximum noise levels to sample from. Maximum 1.0", + type=float, +) +parser.add_argument( + "--no-bidirectional", + dest="bidirectional", + action="store_false", + default=True, + help="Turn off bi-directional RNNs, introduces lookahead convolution", +) args, unknown = parser.parse_known_args() args.use_cuda = eval(args.use_cuda) -datasetCategories = {'Mnist': 10, 'cifar10': 10, "imagenet": 1000, 'emnist': 47, - 'openImg': 596, 'google_speech': 35, 'femnist': 62, 'yelp': 5 - } +datasetCategories = { + "Mnist": 10, + "cifar10": 10, + "imagenet": 1000, + "emnist": 47, + "openImg": 596, + "google_speech": 35, + "femnist": 62, + "yelp": 5, +} # Profiled relative speech w.r.t. Mobilenet -model_factor = {'shufflenet': 0.0644/0.0554, - 'albert': 0.335/0.0554, - 'resnet': 0.135/0.0554, - } +model_factor = { + "shufflenet": 0.0644 / 0.0554, + "albert": 0.335 / 0.0554, + "resnet": 0.135 / 0.0554, +} args.num_class = datasetCategories.get(args.data_set, args.num_classes) for model_name in model_factor: From d4306f47a7d33f864e5c724104afa8e9bfac54f8 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 17:36:41 -0500 Subject: [PATCH 05/17] Add qfedq argument --- fedscale/cloud/aggregation/aggregator.py | 5 +- fedscale/cloud/aggregation/optimizers.py | 70 ++++++++++++------- .../cloud/internal/torch_model_adapter.py | 5 +- 3 files changed, 51 insertions(+), 29 deletions(-) diff --git a/fedscale/cloud/aggregation/aggregator.py b/fedscale/cloud/aggregation/aggregator.py index 05746f80..f69136a3 100755 --- a/fedscale/cloud/aggregation/aggregator.py +++ b/fedscale/cloud/aggregation/aggregator.py @@ -460,7 +460,10 @@ def update_weight_aggregation(self, results): self.model_weights = [weight + update_weights[i] for i, weight in enumerate(self.model_weights)] if self._is_last_result_in_round(): self.model_weights = [np.divide(weight, self.tasks_round) for weight in self.model_weights] - self.model_wrapper.set_weights(copy.deepcopy(self.model_weights)) + self.model_wrapper.set_weights( + copy.deepcopy(self.model_weights), + self.client_training_results + ) def aggregate_test_result(self): diff --git a/fedscale/cloud/aggregation/optimizers.py b/fedscale/cloud/aggregation/optimizers.py index e34329fa..2343c3ab 100644 --- a/fedscale/cloud/aggregation/optimizers.py +++ b/fedscale/cloud/aggregation/optimizers.py @@ -1,8 +1,10 @@ -import numpy as np +import numpy as np import torch + + class TorchServerOptimizer(object): """This is a abstract server optimizer class - + Args: mode (string): mode of gradient aggregation policy args (distionary): Variable arguments for fedscale runtime config. defaults to the setup in arg_parser.py @@ -10,29 +12,37 @@ class TorchServerOptimizer(object): sample_seed (int): Random seed """ - def __init__(self, mode, args, device, sample_seed=233): + def __init__(self, mode, args, device, sample_seed=233): self.mode = mode self.args = args self.device = device - if mode == 'fed-yogi': + if mode == "fed-yogi": from fedscale.utils.optimizer.yogi import YoGi + self.gradient_controller = YoGi( - eta=args.yogi_eta, tau=args.yogi_tau, beta=args.yogi_beta, beta2=args.yogi_beta2) + eta=args.yogi_eta, + tau=args.yogi_tau, + beta=args.yogi_beta, + beta2=args.yogi_beta2, + ) + + def update_round_gradient( + self, last_model, current_model, target_model, client_training_results=None + ): + """update global model based on different policy - def update_round_gradient(self, last_model, current_model, target_model): - """ update global model based on different policy - Args: last_model (list of tensor weight): A list of global model weight in last round. current_model (list of tensor weight): A list of global model weight in this round. target_model (PyTorch or TensorFlow nn module): Aggregated model. - + client_training_results list of gradients from every clients, for q-fedavg + """ - if self.mode == 'fed-yogi': + if self.mode == "fed-yogi": """ - "Adaptive Federated Optimizations", + "Adaptive Federated Optimizations", Sashank J. Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konecný, Sanjiv Kumar, H. Brendan McMahan, ICLR 2021. """ @@ -40,44 +50,52 @@ def update_round_gradient(self, last_model, current_model, target_model): current_model = [x.to(device=self.device) for x in current_model] diff_weight = self.gradient_controller.update( - [pb-pa for pa, pb in zip(last_model, current_model)]) + [pb - pa for pa, pb in zip(last_model, current_model)] + ) new_state_dict = { - name: torch.from_numpy(np.array(last_model[idx] + diff_weight[idx], dtype=np.float32)) + name: torch.from_numpy( + np.array(last_model[idx] + diff_weight[idx], dtype=np.float32) + ) for idx, name in enumerate(target_model.state_dict().keys()) } target_model.load_state_dict(new_state_dict) - elif self.mode == 'q-fedavg': + elif self.mode == "q-fedavg": """ "Fair Resource Allocation in Federated Learning", Tian Li, Maziar Sanjabi, Ahmad Beirami, Virginia Smith, ICLR 2020. """ learning_rate, qfedq = self.args.learning_rate, self.args.qfed_q - Deltas, hs = None, 0. + Deltas, hs = None, 0.0 last_model = [x.to(device=self.device) for x in last_model] - for result in self.client_training_results: + for result in client_training_results: # plug in the weight updates into the gradient - grads = [(u - torch.from_numpy(v).to(device=self.device)) * 1.0 / - learning_rate for u, v in zip(last_model, result['update_weight'])] - loss = result['moving_loss'] + grads = [ + (u - torch.from_numpy(v).to(device=self.device)) + * 1.0 + / learning_rate + for u, v in zip(last_model, result["update_weight"]) + ] + loss = result["moving_loss"] if Deltas is None: - Deltas = [np.float_power( - loss+1e-10, qfedq) * grad for grad in grads] + Deltas = [ + np.float_power(loss + 1e-10, qfedq) * grad for grad in grads + ] else: for idx in range(len(Deltas)): - Deltas[idx] += np.float_power(loss + - 1e-10, qfedq) * grads[idx] + Deltas[idx] += np.float_power(loss + 1e-10, qfedq) * grads[idx] # estimation of the local Lipchitz constant - hs += (qfedq * np.float_power(loss+1e-10, (qfedq-1)) * torch.sum(torch.stack([torch.square( - grad).sum() for grad in grads])) + (1.0/learning_rate) * np.float_power(loss+1e-10, qfedq)) + hs += qfedq * np.float_power(loss + 1e-10, (qfedq - 1)) * torch.sum( + torch.stack([torch.square(grad).sum() for grad in grads]) + ) + (1.0 / learning_rate) * np.float_power(loss + 1e-10, qfedq) # update global model for idx, param in enumerate(target_model.parameters()): - param.data = last_model[idx] - Deltas[idx]/(hs+1e-10) + param.data = last_model[idx] - Deltas[idx] / (hs + 1e-10) else: # The default optimizer, FedAvg, has been applied in aggregator.py on the fly diff --git a/fedscale/cloud/internal/torch_model_adapter.py b/fedscale/cloud/internal/torch_model_adapter.py index e2769923..8c73e10d 100644 --- a/fedscale/cloud/internal/torch_model_adapter.py +++ b/fedscale/cloud/internal/torch_model_adapter.py @@ -20,11 +20,12 @@ def __init__(self, model: torch.nn.Module, optimizer: TorchServerOptimizer = Non self.model = model self.optimizer = optimizer - def set_weights(self, weights: List[np.ndarray], is_aggregator=True): + def set_weights(self, weights: List[np.ndarray], is_aggregator=True, client_training_results=None): """ Set the model's weights to the numpy weights array. :param weights: numpy weights array :param is_aggregator: boolean indicating whether the caller is the aggregator + :param client_training_results: list of gradients from every clients, for q-fedavg """ last_grad_weights = [param.data.clone() for param in self.model.state_dict().values()] new_state_dict = { @@ -35,7 +36,7 @@ def set_weights(self, weights: List[np.ndarray], is_aggregator=True): if self.optimizer and is_aggregator: weights_origin = copy.deepcopy(weights) weights = [torch.tensor(x) for x in weights_origin] - self.optimizer.update_round_gradient(last_grad_weights, weights, self.model) + self.optimizer.update_round_gradient(last_grad_weights, weights, self.model, client_training_results) def get_weights(self) -> List[np.ndarray]: """ From 65bedc012f65b39d5eea297abb4691e2a6722caf Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 17:49:41 -0500 Subject: [PATCH 06/17] Add qfedq argument --- fedscale/cloud/aggregation/aggregator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedscale/cloud/aggregation/aggregator.py b/fedscale/cloud/aggregation/aggregator.py index f69136a3..9c35c849 100755 --- a/fedscale/cloud/aggregation/aggregator.py +++ b/fedscale/cloud/aggregation/aggregator.py @@ -462,7 +462,7 @@ def update_weight_aggregation(self, results): self.model_weights = [np.divide(weight, self.tasks_round) for weight in self.model_weights] self.model_wrapper.set_weights( copy.deepcopy(self.model_weights), - self.client_training_results + client_training_results=self.client_training_results ) From d4c482b6e6783046ac05bb0337ab041aab5e9edf Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 18:36:40 -0500 Subject: [PATCH 07/17] Add qfedq argument --- fedscale/cloud/aggregation/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedscale/cloud/aggregation/optimizers.py b/fedscale/cloud/aggregation/optimizers.py index 2343c3ab..bff95594 100644 --- a/fedscale/cloud/aggregation/optimizers.py +++ b/fedscale/cloud/aggregation/optimizers.py @@ -76,7 +76,7 @@ def update_round_gradient( (u - torch.from_numpy(v).to(device=self.device)) * 1.0 / learning_rate - for u, v in zip(last_model, result["update_weight"]) + for u, v in zip(last_model, result["update_weight"].values()) ] loss = result["moving_loss"] From 0537a95f907928b2be41a3d06f2e1eda65addf2f Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 19:09:41 -0500 Subject: [PATCH 08/17] Add qfedq argument --- fedscale/cloud/aggregation/optimizers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fedscale/cloud/aggregation/optimizers.py b/fedscale/cloud/aggregation/optimizers.py index bff95594..b758a522 100644 --- a/fedscale/cloud/aggregation/optimizers.py +++ b/fedscale/cloud/aggregation/optimizers.py @@ -72,11 +72,15 @@ def update_round_gradient( for result in client_training_results: # plug in the weight updates into the gradient + update_weights = result["update_weights"] + if type(update_weights) is dict: + update_weights = [x for x in update_weights.values()] + weights = [torch.tensor(x).to(device=self.device) for x in update_weights] grads = [ - (u - torch.from_numpy(v).to(device=self.device)) + (u - v) * 1.0 / learning_rate - for u, v in zip(last_model, result["update_weight"].values()) + for u, v in zip(last_model, weights) ] loss = result["moving_loss"] From 0c56dd141989ace9225ed4e94c046e5dc73e747b Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 19:13:28 -0500 Subject: [PATCH 09/17] Fix key error --- fedscale/cloud/aggregation/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedscale/cloud/aggregation/optimizers.py b/fedscale/cloud/aggregation/optimizers.py index b758a522..7cd80706 100644 --- a/fedscale/cloud/aggregation/optimizers.py +++ b/fedscale/cloud/aggregation/optimizers.py @@ -72,7 +72,7 @@ def update_round_gradient( for result in client_training_results: # plug in the weight updates into the gradient - update_weights = result["update_weights"] + update_weights = result["update_weight"] if type(update_weights) is dict: update_weights = [x for x in update_weights.values()] weights = [torch.tensor(x).to(device=self.device) for x in update_weights] From 2778ae0266978b8586f9fe1001d0b433892ee127 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 20:10:22 -0500 Subject: [PATCH 10/17] Try remove 'to device' --- fedscale/cloud/aggregation/optimizers.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/fedscale/cloud/aggregation/optimizers.py b/fedscale/cloud/aggregation/optimizers.py index 7cd80706..eaceb5e3 100644 --- a/fedscale/cloud/aggregation/optimizers.py +++ b/fedscale/cloud/aggregation/optimizers.py @@ -46,8 +46,8 @@ def update_round_gradient( Sashank J. Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konecný, Sanjiv Kumar, H. Brendan McMahan, ICLR 2021. """ - last_model = [x.to(device=self.device) for x in last_model] - current_model = [x.to(device=self.device) for x in current_model] + # last_model = [x.to(device=self.device) for x in last_model] + # current_model = [x.to(device=self.device) for x in current_model] diff_weight = self.gradient_controller.update( [pb - pa for pa, pb in zip(last_model, current_model)] @@ -75,12 +75,14 @@ def update_round_gradient( update_weights = result["update_weight"] if type(update_weights) is dict: update_weights = [x for x in update_weights.values()] - weights = [torch.tensor(x).to(device=self.device) for x in update_weights] + weights = [ + torch.from_numpy(np.asarray(x, dtype=np.float32)).to( + device=self.device + ) + for x in update_weights + ] grads = [ - (u - v) - * 1.0 - / learning_rate - for u, v in zip(last_model, weights) + (u - v) * 1.0 / learning_rate for u, v in zip(last_model, weights) ] loss = result["moving_loss"] From 043af94a99020e6426807a970bdd2bfb448e59db Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 20:39:33 -0500 Subject: [PATCH 11/17] Make model used in test directly transferred from aggregator --- fedscale/cloud/aggregation/aggregator.py | 468 ++++++++++++++--------- fedscale/cloud/execution/executor.py | 13 +- 2 files changed, 292 insertions(+), 189 deletions(-) diff --git a/fedscale/cloud/aggregation/aggregator.py b/fedscale/cloud/aggregation/aggregator.py index 9c35c849..a52d7f39 100755 --- a/fedscale/cloud/aggregation/aggregator.py +++ b/fedscale/cloud/aggregation/aggregator.py @@ -44,13 +44,12 @@ def __init__(self, args): logging.info(f"Job args {args}") self.args = args self.experiment_mode = args.experiment_mode - self.device = args.cuda_device if args.use_cuda else torch.device( - 'cpu') + self.device = args.cuda_device if args.use_cuda else torch.device("cpu") # ======== env information ======== self.this_rank = 0 - self.global_virtual_clock = 0. - self.round_duration = 0. + self.global_virtual_clock = 0.0 + self.round_duration = 0.0 self.resource_manager = ResourceManager(self.experiment_mode) self.client_manager = self.init_client_manager(args=args) @@ -61,7 +60,8 @@ def __init__(self, args): # all weights including bias/#_batch_tracked (e.g., state_dict) self.model_weights = None self.temp_model_path = os.path.join( - logger.logDir, 'model_'+str(args.this_rank)+".npy") + logger.logDir, "model_" + str(args.this_rank) + ".npy" + ) self.last_saved_round = 0 # ======== channels ======== @@ -86,7 +86,7 @@ def __init__(self, args): self.sampled_executors = [] self.round_stragglers = [] - self.model_update_size = 0. + self.model_update_size = 0.0 self.collate_fn = None self.round = 0 @@ -101,27 +101,36 @@ def __init__(self, args): # number of registered executors self.registered_executor_info = set() self.test_result_accumulator = [] - self.testing_history = {'data_set': args.data_set, 'model': args.model, 'sample_mode': args.sample_mode, - 'gradient_policy': args.gradient_policy, 'task': args.task, - 'perf': collections.OrderedDict()} + self.testing_history = { + "data_set": args.data_set, + "model": args.model, + "sample_mode": args.sample_mode, + "gradient_policy": args.gradient_policy, + "task": args.task, + "perf": collections.OrderedDict(), + } self.log_writer = SummaryWriter(log_dir=logger.logDir) if args.wandb_token != "": - os.environ['WANDB_API_KEY'] = args.wandb_token + os.environ["WANDB_API_KEY"] = args.wandb_token self.wandb = wandb if self.wandb.run is None: - self.wandb.init(project=f'fedscale-{args.job_name}', - name=f'aggregator{args.this_rank}-{args.time_stamp}', - group=f'{args.time_stamp}') - self.wandb.config.update({ - "num_participants": args.num_participants, - "data_set": args.data_set, - "model": args.model, - "gradient_policy": args.gradient_policy, - "eval_interval": args.eval_interval, - "rounds": args.rounds, - "batch_size": args.batch_size, - "use_cuda": args.use_cuda - }) + self.wandb.init( + project=f"fedscale-{args.job_name}", + name=f"aggregator{args.this_rank}-{args.time_stamp}", + group=f"{args.time_stamp}", + ) + self.wandb.config.update( + { + "num_participants": args.num_participants, + "data_set": args.data_set, + "model": args.model, + "gradient_policy": args.gradient_policy, + "eval_interval": args.eval_interval, + "rounds": args.rounds, + "batch_size": args.batch_size, + "use_cuda": args.use_cuda, + } + ) else: logging.error("Warning: wandb has already been initialized") # self.wandb.run.name = f'{args.job_name}-{args.time_stamp}' @@ -132,8 +141,7 @@ def __init__(self, args): self.init_task_context() def setup_env(self): - """Set up experiments environment and server optimizer - """ + """Set up experiments environment and server optimizer""" self.setup_seed(seed=1) def setup_seed(self, seed=1): @@ -157,8 +165,8 @@ def init_control_communication(self): if self.experiment_mode == commons.SIMULATION_MODE: num_of_executors = 0 for ip_numgpu in self.args.executor_configs.split("="): - ip, numgpu = ip_numgpu.split(':') - for numexe in numgpu.strip()[1:-1].split(','): + ip, numgpu = ip_numgpu.split(":") + for numexe in numgpu.strip()[1:-1].split(","): for _ in range(int(numexe.strip())): num_of_executors += 1 self.executors = list(range(num_of_executors)) @@ -169,22 +177,22 @@ def init_control_communication(self): self.grpc_server = grpc.server( futures.ThreadPoolExecutor(max_workers=20), options=[ - ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH), - ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), + ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH), + ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH), ], ) - job_api_pb2_grpc.add_JobServiceServicer_to_server( - self, self.grpc_server) - port = '[::]:{}'.format(self.args.ps_port) + job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self.grpc_server) + port = "[::]:{}".format(self.args.ps_port) - logging.info(f'%%%%%%%%%% Opening aggregator server using port {port} %%%%%%%%%%') + logging.info( + f"%%%%%%%%%% Opening aggregator server using port {port} %%%%%%%%%%" + ) self.grpc_server.add_insecure_port(port) self.grpc_server.start() def init_data_communication(self): - """For jumbo traffics (e.g., training results). - """ + """For jumbo traffics (e.g., training results).""" pass def init_model(self): @@ -195,22 +203,24 @@ def init_model(self): self.model_wrapper = TorchModelAdapter( init_model(), optimizer=TorchServerOptimizer( - self.args.gradient_policy, self.args, self.device)) + self.args.gradient_policy, self.args, self.device + ), + ) else: raise ValueError(f"{self.args.engine} is not a supported engine.") self.model_weights = self.model_wrapper.get_weights() def init_task_context(self): - """Initiate execution context for specific tasks - """ + """Initiate execution context for specific tasks""" if self.args.task == "detection": cfg_from_file(self.args.cfg_file) np.random.seed(self.cfg.RNG_SEED) self.imdb, _, _, _ = combined_roidb( - "voc_2007_test", ['DATA_DIR', self.args.data_dir], server=True) + "voc_2007_test", ["DATA_DIR", self.args.data_dir], server=True + ) def init_client_manager(self, args): - """ Initialize client sampler + """Initialize client sampler Args: args (dictionary): Variable arguments for fedscale runtime config. defaults to the setup in arg_parser.py @@ -247,7 +257,7 @@ def load_client_profile(self, file_path): """ global_client_profile = {} if os.path.exists(file_path): - with open(file_path, 'rb') as fin: + with open(file_path, "rb") as fin: # {client_id: [computer, bandwidth]} global_client_profile = pickle.load(fin) @@ -262,28 +272,37 @@ def client_register_handler(self, executorId, info): """ logging.info(f"Loading {len(info['size'])} client traces ...") - for _size in info['size']: + for _size in info["size"]: # since the worker rankId starts from 1, we also configure the initial dataId as 1 - mapped_id = (self.num_of_clients + 1) % len( - self.client_profiles) if len(self.client_profiles) > 0 else 1 + mapped_id = ( + (self.num_of_clients + 1) % len(self.client_profiles) + if len(self.client_profiles) > 0 + else 1 + ) systemProfile = self.client_profiles.get( - mapped_id, {'computation': 1.0, 'communication': 1.0}) + mapped_id, {"computation": 1.0, "communication": 1.0} + ) client_id = ( - self.num_of_clients + 1) if self.experiment_mode == commons.SIMULATION_MODE else executorId + (self.num_of_clients + 1) + if self.experiment_mode == commons.SIMULATION_MODE + else executorId + ) self.client_manager.register_client( - executorId, client_id, size=_size, speed=systemProfile) + executorId, client_id, size=_size, speed=systemProfile + ) self.client_manager.registerDuration( client_id, batch_size=self.args.batch_size, local_steps=self.args.local_steps, upload_size=self.model_update_size, - download_size=self.model_update_size + download_size=self.model_update_size, ) self.num_of_clients += 1 - logging.info("Info of all feasible clients {}".format( - self.client_manager.getDataInfo())) + logging.info( + "Info of all feasible clients {}".format(self.client_manager.getDataInfo()) + ) def executor_info_handler(self, executorId, info): """Handler for register executor info and it will start the round after number of @@ -296,12 +315,12 @@ def executor_info_handler(self, executorId, info): """ self.registered_executor_info.add(executorId) logging.info( - f"Received executor {executorId} information, {len(self.registered_executor_info)}/{len(self.executors)}") + f"Received executor {executorId} information, {len(self.registered_executor_info)}/{len(self.executors)}" + ) # In this simulation, we run data split on each worker, so collecting info from one executor is enough # Waiting for data information from executors, or timeout if self.experiment_mode == commons.SIMULATION_MODE: - if len(self.registered_executor_info) == len(self.executors): self.client_register_handler(executorId, info) # start to sample clients @@ -335,42 +354,58 @@ def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): for client_to_run in sampled_clients: client_cfg = self.client_conf.get(client_to_run, self.args) - exe_cost = self.client_manager.get_completion_time(client_to_run, - batch_size=client_cfg.batch_size, - local_steps=client_cfg.local_steps, - upload_size=self.model_update_size, - download_size=self.model_update_size) + exe_cost = self.client_manager.get_completion_time( + client_to_run, + batch_size=client_cfg.batch_size, + local_steps=client_cfg.local_steps, + upload_size=self.model_update_size, + download_size=self.model_update_size, + ) - roundDuration = exe_cost['computation'] + \ - exe_cost['communication'] + roundDuration = exe_cost["computation"] + exe_cost["communication"] # if the client is not active by the time of collection, we consider it is lost in this round - if self.client_manager.isClientActive(client_to_run, roundDuration + self.global_virtual_clock): + if self.client_manager.isClientActive( + client_to_run, roundDuration + self.global_virtual_clock + ): sampledClientsReal.append(client_to_run) completionTimes.append(roundDuration) completed_client_clock[client_to_run] = exe_cost - num_clients_to_collect = min( - num_clients_to_collect, len(completionTimes)) + num_clients_to_collect = min(num_clients_to_collect, len(completionTimes)) # 2. get the top-k completions to remove stragglers workers_sorted_by_completion_time = sorted( - range(len(completionTimes)), key=lambda k: completionTimes[k]) + range(len(completionTimes)), key=lambda k: completionTimes[k] + ) top_k_index = workers_sorted_by_completion_time[:num_clients_to_collect] clients_to_run = [sampledClientsReal[k] for k in top_k_index] - stragglers = [sampledClientsReal[k] - for k in workers_sorted_by_completion_time[num_clients_to_collect:]] + stragglers = [ + sampledClientsReal[k] + for k in workers_sorted_by_completion_time[num_clients_to_collect:] + ] round_duration = completionTimes[top_k_index[-1]] completionTimes.sort() - return (clients_to_run, stragglers, - completed_client_clock, round_duration, - completionTimes[:num_clients_to_collect]) + return ( + clients_to_run, + stragglers, + completed_client_clock, + round_duration, + completionTimes[:num_clients_to_collect], + ) else: completed_client_clock = { - client: {'computation': 1, 'communication': 1} for client in sampled_clients} + client: {"computation": 1, "communication": 1} + for client in sampled_clients + } completionTimes = [1 for c in sampled_clients] - return (sampled_clients, sampled_clients, completed_client_clock, - 1, completionTimes) + return ( + sampled_clients, + sampled_clients, + completed_client_clock, + 1, + completionTimes, + ) def run(self): """Start running the aggregator server by setting up execution @@ -378,14 +413,16 @@ def run(self): """ self.setup_env() self.client_profiles = self.load_client_profile( - file_path=self.args.device_conf_file) - + file_path=self.args.device_conf_file + ) + self.init_control_communication() self.init_data_communication() self.init_model() - self.model_update_size = sys.getsizeof( - pickle.dumps(self.model_wrapper)) / 1024.0 * 8. # kbits + self.model_update_size = ( + sys.getsizeof(pickle.dumps(self.model_wrapper)) / 1024.0 * 8.0 + ) # kbits self.event_monitor() self.stop() @@ -407,9 +444,11 @@ def select_participants(self, select_num_participants, overcommitment=1.3): list of int: The list of sampled clients id. """ - return sorted(self.client_manager.select_participants( - int(select_num_participants * overcommitment), - cur_time=self.global_virtual_clock), + return sorted( + self.client_manager.select_participants( + int(select_num_participants * overcommitment), + cur_time=self.global_virtual_clock, + ), ) def client_completion_handler(self, results): @@ -424,19 +463,20 @@ def client_completion_handler(self, results): # -results = {'client_id':client_id, 'update_weight': model_param, 'moving_loss': round_train_loss, # 'trained_size': count, 'wall_duration': time_cost, 'success': is_success 'utility': utility} - if self.args.gradient_policy in ['q-fedavg']: + if self.args.gradient_policy in ["q-fedavg"]: self.client_training_results.append(results) # Feed metrics to client sampler - self.stats_util_accumulator.append(results['utility']) - self.loss_accumulator.append(results['moving_loss']) - - self.client_manager.register_feedback(results['client_id'], results['utility'], - auxi=math.sqrt( - results['moving_loss']), - time_stamp=self.round, - duration=self.virtual_client_clock[results['client_id']]['computation'] + - self.virtual_client_clock[results['client_id']]['communication'] - ) + self.stats_util_accumulator.append(results["utility"]) + self.loss_accumulator.append(results["moving_loss"]) + + self.client_manager.register_feedback( + results["client_id"], + results["utility"], + auxi=math.sqrt(results["moving_loss"]), + time_stamp=self.round, + duration=self.virtual_client_clock[results["client_id"]]["computation"] + + self.virtual_client_clock[results["client_id"]]["communication"], + ) # ================== Aggregate weights ====================== self.update_lock.acquire() @@ -451,20 +491,24 @@ def update_weight_aggregation(self, results): :param results: the results collected from the client. """ - update_weights = results['update_weight'] + update_weights = results["update_weight"] if type(update_weights) is dict: update_weights = [x for x in update_weights.values()] if self._is_first_result_in_round(): self.model_weights = update_weights else: - self.model_weights = [weight + update_weights[i] for i, weight in enumerate(self.model_weights)] + self.model_weights = [ + weight + update_weights[i] + for i, weight in enumerate(self.model_weights) + ] if self._is_last_result_in_round(): - self.model_weights = [np.divide(weight, self.tasks_round) for weight in self.model_weights] + self.model_weights = [ + np.divide(weight, self.tasks_round) for weight in self.model_weights + ] self.model_wrapper.set_weights( copy.deepcopy(self.model_weights), - client_training_results=self.client_training_results - ) - + client_training_results=self.client_training_results, + ) def aggregate_test_result(self): accumulator = self.test_result_accumulator[0] @@ -473,33 +517,45 @@ def aggregate_test_result(self): for key in accumulator: if key == "boxes": for j in range(596): - accumulator[key][j] = accumulator[key][j] + \ - self.test_result_accumulator[i][key][j] + accumulator[key][j] = ( + accumulator[key][j] + + self.test_result_accumulator[i][key][j] + ) else: accumulator[key] += self.test_result_accumulator[i][key] else: for key in accumulator: accumulator[key] += self.test_result_accumulator[i][key] - self.testing_history['perf'][self.round] = {'round': self.round, 'clock': self.global_virtual_clock} + self.testing_history["perf"][self.round] = { + "round": self.round, + "clock": self.global_virtual_clock, + } for metric_name in accumulator.keys(): - if metric_name == 'test_loss': - self.testing_history['perf'][self.round]['loss'] = accumulator['test_loss'] \ - if self.args.task == "detection" else accumulator['test_loss'] / accumulator['test_len'] - elif metric_name not in ['test_len']: - self.testing_history['perf'][self.round][metric_name] \ - = accumulator[metric_name] / accumulator['test_len'] - - round_perf = self.testing_history['perf'][self.round] + if metric_name == "test_loss": + self.testing_history["perf"][self.round]["loss"] = ( + accumulator["test_loss"] + if self.args.task == "detection" + else accumulator["test_loss"] / accumulator["test_len"] + ) + elif metric_name not in ["test_len"]: + self.testing_history["perf"][self.round][metric_name] = ( + accumulator[metric_name] / accumulator["test_len"] + ) + + round_perf = self.testing_history["perf"][self.round] logging.info( - "FL Testing in round: {}, virtual_clock: {}, results: {}" - .format(self.round, self.global_virtual_clock, round_perf)) + "FL Testing in round: {}, virtual_clock: {}, results: {}".format( + self.round, self.global_virtual_clock, round_perf + ) + ) def update_default_task_config(self): - """Update the default task configuration after each round - """ + """Update the default task configuration after each round""" if self.round % self.args.decay_round == 0: self.args.learning_rate = max( - self.args.learning_rate * self.args.decay_factor, self.args.min_learning_rate) + self.args.learning_rate * self.args.decay_factor, + self.args.min_learning_rate, + ) def round_completion_handler(self): """Triggered upon the round completion, it registers the last round execution info, @@ -507,18 +563,25 @@ def round_completion_handler(self): """ self.global_virtual_clock += self.round_duration self.round += 1 - last_round_avg_util = sum(self.stats_util_accumulator) / max(1, len(self.stats_util_accumulator)) + last_round_avg_util = sum(self.stats_util_accumulator) / max( + 1, len(self.stats_util_accumulator) + ) # assign avg reward to explored, but not ran workers for client_id in self.round_stragglers: - self.client_manager.register_feedback(client_id, last_round_avg_util, - time_stamp=self.round, - duration=self.virtual_client_clock[client_id]['computation'] + - self.virtual_client_clock[client_id]['communication'], - success=False) + self.client_manager.register_feedback( + client_id, + last_round_avg_util, + time_stamp=self.round, + duration=self.virtual_client_clock[client_id]["computation"] + + self.virtual_client_clock[client_id]["communication"], + success=False, + ) avg_loss = sum(self.loss_accumulator) / max(1, len(self.loss_accumulator)) - logging.info(f"Wall clock: {round(self.global_virtual_clock)} s, round: {self.round}, Planned participants: " + - f"{len(self.sampled_participants)}, Succeed participants: {len(self.stats_util_accumulator)}, Training loss: {avg_loss}") + logging.info( + f"Wall clock: {round(self.global_virtual_clock)} s, round: {self.round}, Planned participants: " + + f"{len(self.sampled_participants)}, Succeed participants: {len(self.stats_util_accumulator)}, Training loss: {avg_loss}" + ) # dump round completion information to tensorboard if len(self.loss_accumulator): @@ -526,10 +589,18 @@ def round_completion_handler(self): # update select participants self.sampled_participants = self.select_participants( - select_num_participants=self.args.num_participants, overcommitment=self.args.overcommitment) - (clients_to_run, round_stragglers, virtual_client_clock, round_duration, - flatten_client_duration) = self.tictak_client_tasks( - self.sampled_participants, self.args.num_participants) + select_num_participants=self.args.num_participants, + overcommitment=self.args.overcommitment, + ) + ( + clients_to_run, + round_stragglers, + virtual_client_clock, + round_duration, + flatten_client_duration, + ) = self.tictak_client_tasks( + self.sampled_participants, self.args.num_participants + ) logging.info(f"Selected participants to run: {clients_to_run}") @@ -539,11 +610,9 @@ def round_completion_handler(self): # Update executors and participants if self.experiment_mode == commons.SIMULATION_MODE: - self.sampled_executors = list( - self.individual_client_events.keys()) + self.sampled_executors = list(self.individual_client_events.keys()) else: - self.sampled_executors = [str(c_id) - for c_id in self.sampled_participants] + self.sampled_executors = [str(c_id) for c_id in self.sampled_participants] self.round_stragglers = round_stragglers self.virtual_client_clock = virtual_client_clock self.flatten_client_duration = np.array(flatten_client_duration) @@ -565,45 +634,61 @@ def round_completion_handler(self): self.broadcast_aggregator_events(commons.START_ROUND) def log_train_result(self, avg_loss): - """Log training result on TensorBoard and optionally WanDB - """ - self.log_writer.add_scalar('Train/round_to_loss', avg_loss, self.round) + """Log training result on TensorBoard and optionally WanDB""" + self.log_writer.add_scalar("Train/round_to_loss", avg_loss, self.round) self.log_writer.add_scalar( - 'Train/time_to_train_loss (min)', avg_loss, self.global_virtual_clock / 60.) + "Train/time_to_train_loss (min)", avg_loss, self.global_virtual_clock / 60.0 + ) self.log_writer.add_scalar( - 'Train/round_duration (min)', self.round_duration / 60., self.round) + "Train/round_duration (min)", self.round_duration / 60.0, self.round + ) self.log_writer.add_histogram( - 'Train/client_duration (min)', self.flatten_client_duration, self.round) + "Train/client_duration (min)", self.flatten_client_duration, self.round + ) if self.wandb != None: - self.wandb.log({ - 'Train/round_to_loss': avg_loss, - 'Train/round_duration (min)': self.round_duration/60., - 'Train/client_duration (min)': self.flatten_client_duration, - 'Train/time_to_round (min)': self.global_virtual_clock/60., - }, step=self.round) - + self.wandb.log( + { + "Train/round_to_loss": avg_loss, + "Train/round_duration (min)": self.round_duration / 60.0, + "Train/client_duration (min)": self.flatten_client_duration, + "Train/time_to_round (min)": self.global_virtual_clock / 60.0, + }, + step=self.round, + ) + def log_test_result(self): - """Log testing result on TensorBoard and optionally WanDB - """ + """Log testing result on TensorBoard and optionally WanDB""" self.log_writer.add_scalar( - 'Test/round_to_loss', self.testing_history['perf'][self.round]['loss'], self.round) + "Test/round_to_loss", + self.testing_history["perf"][self.round]["loss"], + self.round, + ) self.log_writer.add_scalar( - 'Test/round_to_accuracy', self.testing_history['perf'][self.round]['top_1'], self.round) - self.log_writer.add_scalar('Test/time_to_test_loss (min)', self.testing_history['perf'][self.round]['loss'], - self.global_virtual_clock / 60.) - self.log_writer.add_scalar('Test/time_to_test_accuracy (min)', self.testing_history['perf'][self.round]['top_1'], - self.global_virtual_clock / 60.) + "Test/round_to_accuracy", + self.testing_history["perf"][self.round]["top_1"], + self.round, + ) + self.log_writer.add_scalar( + "Test/time_to_test_loss (min)", + self.testing_history["perf"][self.round]["loss"], + self.global_virtual_clock / 60.0, + ) + self.log_writer.add_scalar( + "Test/time_to_test_accuracy (min)", + self.testing_history["perf"][self.round]["top_1"], + self.global_virtual_clock / 60.0, + ) def save_model(self): - """Save model to the wandb server if enabled - - """ + """Save model to the wandb server if enabled""" if parser.args.save_checkpoint and self.last_saved_round < self.round: self.last_saved_round = self.round np.save(self.temp_model_path, self.model_weights) if self.wandb != None: - artifact = self.wandb.Artifact(name='model_'+str(self.this_rank), type='model') + artifact = self.wandb.Artifact( + name="model_" + str(self.this_rank), type="model" + ) artifact.add_file(local_path=self.temp_model_path) self.wandb.log_artifact(artifact) @@ -619,7 +704,7 @@ def deserialize_response(self, responses): return pickle.loads(responses) def serialize_response(self, responses): - """ Serialize the response to send to server upon assigned job completion + """Serialize the response to send to server upon assigned job completion Args: responses (ServerResponse): Serialized response from server. @@ -639,7 +724,7 @@ def testing_completion_handler(self, client_id, results): """ - results = results['results'] + results = results["results"] # List append is thread-safe self.test_result_accumulator.append(results) @@ -647,10 +732,9 @@ def testing_completion_handler(self, client_id, results): # Have collected all testing results if len(self.test_result_accumulator) == len(self.executors): - self.aggregate_test_result() # Dump the testing result - with open(os.path.join(logger.logDir, 'testing_perf'), 'wb') as fout: + with open(os.path.join(logger.logDir, "testing_perf"), "wb") as fout: pickle.dump(self.testing_history, fout) self.save_model() @@ -697,7 +781,7 @@ def get_client_conf(self, client_id): """ conf = { - 'learning_rate': self.args.learning_rate, + "learning_rate": self.args.learning_rate, } return conf @@ -716,7 +800,7 @@ def create_client_task(self, executor_id): # NOTE: model = None then the executor will load the global model broadcasted in UPDATE_MODEL if next_client_id is not None: config = self.get_client_conf(next_client_id) - train_config = {'client_id': next_client_id, 'task_config': config} + train_config = {"client_id": next_client_id, "task_config": config} return train_config, self.model_wrapper.get_weights() def get_test_config(self, client_id): @@ -729,7 +813,7 @@ def get_test_config(self, client_id): dictionary: The testing config for new task. """ - return {'client_id': client_id} + return {"client_id": client_id}, self.model_wrapper.get_weights() def get_shutdown_config(self, client_id): """Shutdown config for client, developers can further define personalized client config here. @@ -741,10 +825,10 @@ def get_shutdown_config(self, client_id): dictionary: Shutdown config for new task. """ - return {'client_id': client_id} + return {"client_id": client_id} def add_event_handler(self, client_id, event, meta, data): - """ Due to the large volume of requests, we will put all events into a queue first. + """Due to the large volume of requests, we will put all events into a queue first. Args: client_id (int): The client id. @@ -780,8 +864,9 @@ def CLIENT_REGISTER(self, request, context): self.executor_info_handler(executor_id, executor_info) dummy_data = self.serialize_response(commons.DUMMY_RESPONSE) - return job_api_pb2.ServerResponse(event=commons.DUMMY_EVENT, - meta=dummy_data, data=dummy_data) + return job_api_pb2.ServerResponse( + event=commons.DUMMY_EVENT, meta=dummy_data, data=dummy_data + ) def CLIENT_PING(self, request, context): """Handle client ping requests @@ -805,25 +890,27 @@ def CLIENT_PING(self, request, context): else: current_event = self.individual_client_events[executor_id].popleft() if current_event == commons.CLIENT_TRAIN: - response_msg, response_data = self.create_client_task( - executor_id) + response_msg, response_data = self.create_client_task(executor_id) if response_msg is None: current_event = commons.DUMMY_EVENT if self.experiment_mode != commons.SIMULATION_MODE: self.individual_client_events[executor_id].append( - commons.CLIENT_TRAIN) + commons.CLIENT_TRAIN + ) elif current_event == commons.MODEL_TEST: - response_msg = self.get_test_config(client_id) + response_msg, response_data = self.get_test_config(client_id) elif current_event == commons.UPDATE_MODEL: response_data = self.model_wrapper.get_weights() elif current_event == commons.SHUT_DOWN: response_msg = self.get_shutdown_config(executor_id) response_msg, response_data = self.serialize_response( - response_msg), self.serialize_response(response_data) + response_msg + ), self.serialize_response(response_data) # NOTE: in simulation mode, response data is pickle for faster (de)serialization - response = job_api_pb2.ServerResponse(event=current_event, - meta=response_msg, data=response_data) + response = job_api_pb2.ServerResponse( + event=current_event, meta=response_msg, data=response_data + ) if current_event != commons.DUMMY_EVENT: logging.info(f"Issue EVENT ({current_event}) to EXECUTOR ({executor_id})") @@ -840,7 +927,11 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): """ - executor_id, client_id, event = request.executor_id, request.client_id, request.event + executor_id, client_id, event = ( + request.executor_id, + request.client_id, + request.event, + ) execution_status, execution_msg = request.status, request.msg meta_result, data_result = request.meta_result, request.data_result @@ -848,27 +939,31 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): # Training results may be uploaded in CLIENT_EXECUTE_RESULT request later, # so we need to specify whether to ask client to do so (in case of straggler/timeout in real FL). if execution_status is False: - logging.error(f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}") + logging.error( + f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}" + ) # TODO: whether we should schedule tasks when client_ping or client_complete if self.resource_manager.has_next_task(executor_id): # NOTE: we do not pop the train immediately in simulation mode, # since the executor may run multiple clients - if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: + if ( + commons.CLIENT_TRAIN + not in self.individual_client_events[executor_id] + ): self.individual_client_events[executor_id].append( - commons.CLIENT_TRAIN) + commons.CLIENT_TRAIN + ) elif event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): - self.add_event_handler( - executor_id, event, meta_result, data_result) + self.add_event_handler(executor_id, event, meta_result, data_result) else: logging.error(f"Received undefined event {event} from client {client_id}") return self.CLIENT_PING(request, context) def event_monitor(self): - """Activate event handler according to the received new message - """ + """Activate event handler according to the received new message""" logging.info("Start monitoring events ...") while True: @@ -880,7 +975,6 @@ def event_monitor(self): self.dispatch_client_events(current_event) elif current_event == commons.START_ROUND: - self.dispatch_client_events(commons.CLIENT_TRAIN) elif current_event == commons.SHUT_DOWN: @@ -889,17 +983,22 @@ def event_monitor(self): # Handle events queued on the aggregator elif len(self.server_events_queue) > 0: - client_id, current_event, meta, data = self.server_events_queue.popleft() + ( + client_id, + current_event, + meta, + data, + ) = self.server_events_queue.popleft() if current_event == commons.UPLOAD_MODEL: - self.client_completion_handler( - self.deserialize_response(data)) + self.client_completion_handler(self.deserialize_response(data)) if len(self.stats_util_accumulator) == self.tasks_round: self.round_completion_handler() elif current_event == commons.MODEL_TEST: self.testing_completion_handler( - client_id, self.deserialize_response(data)) + client_id, self.deserialize_response(data) + ) else: logging.error(f"Event {current_event} is not defined") @@ -909,8 +1008,7 @@ def event_monitor(self): time.sleep(0.1) def stop(self): - """Stop the aggregator - """ + """Stop the aggregator""" logging.info(f"Terminating the aggregator ...") if self.wandb != None: self.wandb.finish() diff --git a/fedscale/cloud/execution/executor.py b/fedscale/cloud/execution/executor.py index 7a8b9d6c..072cb7d9 100755 --- a/fedscale/cloud/execution/executor.py +++ b/fedscale/cloud/execution/executor.py @@ -234,7 +234,7 @@ def Test(self, config): config (dictionary): The client testing config. """ - test_res = self.testing_handler() + test_res = self.testing_handler(model=config["model"]) test_res = {"executorId": self.this_rank, "results": test_res} # Report execution completion information @@ -332,7 +332,7 @@ def training_handler(self, client_id, conf, model): return train_res - def testing_handler(self): + def testing_handler(self, model): """Test model Args: @@ -342,6 +342,7 @@ def testing_handler(self): dictionary: The test result """ + self.model_adapter.set_weights(model, is_aggregator=False) test_config = self.override_conf( { "rank": self.this_rank, @@ -360,7 +361,7 @@ def testing_handler(self): ) test_results = client.test( - data_loader, self.model_adapter.get_model(), test_config + data_loader, model=self.model_adapter.get_model(), conf=test_config ) self.log_test_result(test_results) gc.collect() @@ -434,7 +435,11 @@ def event_monitor(self): ) elif current_event == commons.MODEL_TEST: - self.Test(self.deserialize_response(request.meta)) + test_config = self.deserialize_response(request.meta) + test_model = self.deserialize_response(request.data) + test_config["model"] = test_model + test_config["client_id"] = int(test_config["client_id"]) + self.Test(test_config) elif current_event == commons.UPDATE_MODEL: model_weights = self.deserialize_response(request.data) From f5d487bac402a1e6b060d38e083fa41031980f5f Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 21:44:06 -0500 Subject: [PATCH 12/17] Fix yogi init value --- fedscale/cloud/aggregation/optimizers.py | 4 +-- fedscale/utils/optimizer/yogi.py | 31 ++++++++++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/fedscale/cloud/aggregation/optimizers.py b/fedscale/cloud/aggregation/optimizers.py index eaceb5e3..201ffccc 100644 --- a/fedscale/cloud/aggregation/optimizers.py +++ b/fedscale/cloud/aggregation/optimizers.py @@ -46,8 +46,8 @@ def update_round_gradient( Sashank J. Reddi, Zachary Charles, Manzil Zaheer, Zachary Garrett, Keith Rush, Jakub Konecný, Sanjiv Kumar, H. Brendan McMahan, ICLR 2021. """ - # last_model = [x.to(device=self.device) for x in last_model] - # current_model = [x.to(device=self.device) for x in current_model] + last_model = [x.to(device=self.device) for x in last_model] + current_model = [x.to(device=self.device) for x in current_model] diff_weight = self.gradient_controller.update( [pb - pa for pa, pb in zip(last_model, current_model)] diff --git a/fedscale/utils/optimizer/yogi.py b/fedscale/utils/optimizer/yogi.py index b5d4f45f..826d15ef 100755 --- a/fedscale/utils/optimizer/yogi.py +++ b/fedscale/utils/optimizer/yogi.py @@ -1,5 +1,5 @@ import torch - +import numpy as np class YoGi(): def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): @@ -8,28 +8,27 @@ def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): self.beta = beta self.v_t = [] - self.delta_t = [] + self.m_t = [] self.beta2 = beta2 def update(self, gradients): update_gradients = [] + if not self.v_t: + self.v_t = [ np.asarray([self.tau**2 for _ in g], dtype=np.float32) for g in gradients] + self.m_t = [ np.full(len(g), 0.0, dtype=np.float32) for g in gradients] + for idx, gradient in enumerate(gradients): gradient_square = gradient**2 - if len(self.v_t) <= idx: - #gradient_square = gradient**2 - self.v_t.append(gradient_square) - self.delta_t.append(gradient) - else: - # yogi - self.delta_t[idx] = self.beta * \ - self.delta_t[idx] + (1.-self.beta) * gradient - #gradient_square = self.delta_t[idx]**2 - self.v_t[idx] = self.v_t[idx] - (1.-self.beta2) * gradient_square * torch.sign( - self.v_t[idx] - gradient_square) - yogi_learning_rate = self.eta / \ - (torch.sqrt(self.v_t[idx]) + self.tau) + + self.m_t[idx] = self.beta * \ + self.m_t[idx] + (1.-self.beta) * gradient + + self.v_t[idx] = self.v_t[idx] - (1.-self.beta2) * gradient_square * torch.sign( + self.v_t[idx] - gradient_square) + yogi_learning_rate = self.eta / \ + (torch.sqrt(self.v_t[idx]) + self.tau) - update_gradients.append(yogi_learning_rate * self.delta_t[idx]) + update_gradients.append(yogi_learning_rate * self.m_t[idx]) if len(update_gradients) == 0: update_gradients = gradients From 5ee6d2a829a5c53814dafd3568100807916a0c72 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 21:47:45 -0500 Subject: [PATCH 13/17] Fix yogi init value --- fedscale/utils/optimizer/yogi.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fedscale/utils/optimizer/yogi.py b/fedscale/utils/optimizer/yogi.py index 826d15ef..4a0dc4b6 100755 --- a/fedscale/utils/optimizer/yogi.py +++ b/fedscale/utils/optimizer/yogi.py @@ -1,7 +1,8 @@ import torch import numpy as np -class YoGi(): + +class YoGi: def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): self.eta = eta self.tau = tau @@ -14,19 +15,18 @@ def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): def update(self, gradients): update_gradients = [] if not self.v_t: - self.v_t = [ np.asarray([self.tau**2 for _ in g], dtype=np.float32) for g in gradients] - self.m_t = [ np.full(len(g), 0.0, dtype=np.float32) for g in gradients] + self.v_t = [np.full(len(g), self.tau**2, dtype=np.float32) for g in gradients] + self.m_t = [np.full(len(g), 0.0, dtype=np.float32) for g in gradients] for idx, gradient in enumerate(gradients): gradient_square = gradient**2 - - self.m_t[idx] = self.beta * \ - self.m_t[idx] + (1.-self.beta) * gradient - - self.v_t[idx] = self.v_t[idx] - (1.-self.beta2) * gradient_square * torch.sign( - self.v_t[idx] - gradient_square) - yogi_learning_rate = self.eta / \ - (torch.sqrt(self.v_t[idx]) + self.tau) + + self.m_t[idx] = self.beta * self.m_t[idx] + (1.0 - self.beta) * gradient + + self.v_t[idx] = self.v_t[idx] - ( + 1.0 - self.beta2 + ) * gradient_square * torch.sign(self.v_t[idx] - gradient_square) + yogi_learning_rate = self.eta / (torch.sqrt(self.v_t[idx]) + self.tau) update_gradients.append(yogi_learning_rate * self.m_t[idx]) From d897370e84479bfec99e8762474101e78186fa8b Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 21:58:53 -0500 Subject: [PATCH 14/17] Fix init bug --- fedscale/utils/optimizer/yogi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fedscale/utils/optimizer/yogi.py b/fedscale/utils/optimizer/yogi.py index 4a0dc4b6..2fcde2a5 100755 --- a/fedscale/utils/optimizer/yogi.py +++ b/fedscale/utils/optimizer/yogi.py @@ -15,8 +15,8 @@ def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): def update(self, gradients): update_gradients = [] if not self.v_t: - self.v_t = [np.full(len(g), self.tau**2, dtype=np.float32) for g in gradients] - self.m_t = [np.full(len(g), 0.0, dtype=np.float32) for g in gradients] + self.v_t = [torch.full_like(g, self.tau**2) for g in gradients] + self.m_t = [torch.full_like(g, 0.0) for g in gradients] for idx, gradient in enumerate(gradients): gradient_square = gradient**2 From 442c85cf4d5fa726b11f3bfd852f96f4cdc0f1c1 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 22:54:03 -0500 Subject: [PATCH 15/17] Fix init bug --- fedscale/utils/optimizer/yogi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedscale/utils/optimizer/yogi.py b/fedscale/utils/optimizer/yogi.py index 2fcde2a5..0302b974 100755 --- a/fedscale/utils/optimizer/yogi.py +++ b/fedscale/utils/optimizer/yogi.py @@ -15,7 +15,7 @@ def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): def update(self, gradients): update_gradients = [] if not self.v_t: - self.v_t = [torch.full_like(g, self.tau**2) for g in gradients] + self.v_t = [g**2 for g in gradients] self.m_t = [torch.full_like(g, 0.0) for g in gradients] for idx, gradient in enumerate(gradients): From d889ac86faeeed5dccfe02982de02a71d375c820 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sat, 16 Dec 2023 22:57:18 -0500 Subject: [PATCH 16/17] Fix init bug --- fedscale/utils/optimizer/yogi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fedscale/utils/optimizer/yogi.py b/fedscale/utils/optimizer/yogi.py index 0302b974..2d73d0ad 100755 --- a/fedscale/utils/optimizer/yogi.py +++ b/fedscale/utils/optimizer/yogi.py @@ -15,7 +15,7 @@ def __init__(self, eta=1e-2, tau=1e-3, beta=0.9, beta2=0.99): def update(self, gradients): update_gradients = [] if not self.v_t: - self.v_t = [g**2 for g in gradients] + self.v_t = [torch.full_like(g, self.tau) for g in gradients] self.m_t = [torch.full_like(g, 0.0) for g in gradients] for idx, gradient in enumerate(gradients): From 7918a28f294811ee1409320c1cccd7a6b37c8bb4 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sun, 17 Dec 2023 08:05:52 -0500 Subject: [PATCH 17/17] Validated; change optimizer naming in config --- benchmark/configs/openimage/openimage.yml | 2 +- benchmark/configs/others/local_dp.yml | 2 +- benchmark/configs/others/rl_conf.yml | 2 +- benchmark/configs/speech/google_speech.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmark/configs/openimage/openimage.yml b/benchmark/configs/openimage/openimage.yml index 829d3d93..96feb4e0 100644 --- a/benchmark/configs/openimage/openimage.yml +++ b/benchmark/configs/openimage/openimage.yml @@ -51,7 +51,7 @@ job_conf: - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - model: shufflenet_v2_x2_0 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 - - gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default + - gradient_policy: fed-yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default - eval_interval: 30 # How many rounds to run a testing on the testing set - rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples diff --git a/benchmark/configs/others/local_dp.yml b/benchmark/configs/others/local_dp.yml index 3c4c3747..78dc0c8e 100644 --- a/benchmark/configs/others/local_dp.yml +++ b/benchmark/configs/others/local_dp.yml @@ -38,7 +38,7 @@ job_conf: # - device_avail_file: /users/JIACHEN/FLPerf-Cluster/client_datamap/client_behave_trace - sample_mode: random # Client selection: random, oort, random by default - model: resnet18 # Models: shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 - - gradient_policy: fedavg # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default + - gradient_policy: fed-avg # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default - eval_interval: 10 # How many rounds to run a testing on the testing set - rounds: 500 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 20 diff --git a/benchmark/configs/others/rl_conf.yml b/benchmark/configs/others/rl_conf.yml index bf8e04aa..b069cb5e 100644 --- a/benchmark/configs/others/rl_conf.yml +++ b/benchmark/configs/others/rl_conf.yml @@ -48,7 +48,7 @@ job_conf: - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - model: dqn # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 - - gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default + - gradient_policy: fed-yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default - eval_interval: 50 # How many rounds to run a testing on the testing set - rounds: 1000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples diff --git a/benchmark/configs/speech/google_speech.yml b/benchmark/configs/speech/google_speech.yml index 8d4696f5..a7875f74 100644 --- a/benchmark/configs/speech/google_speech.yml +++ b/benchmark/configs/speech/google_speech.yml @@ -50,7 +50,7 @@ job_conf: - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - model: resnet34 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 - - gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default + - gradient_policy: fed-yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default - eval_interval: 30 # How many rounds to run a testing on the testing set - rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples