From 0d35771629f65a9a06ad7e66dd11bfbe06091971 Mon Sep 17 00:00:00 2001
From: juyongjiang
Date: Fri, 10 Sep 2021 19:51:20 +0900
Subject: [PATCH] Implementing multi-GPU training for RecBole

---
 recbole/config/configurator.py         |  3 +-
 recbole/properties/model/BERT4Rec.yaml |  6 ++-
 recbole/properties/overall.yaml        |  2 +
 recbole/trainer/trainer.py             | 66 ++++++++++++++++++++++++++
 4 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/recbole/config/configurator.py b/recbole/config/configurator.py
index 854b42dc1..2b8e9ad1d 100644
--- a/recbole/config/configurator.py
+++ b/recbole/config/configurator.py
@@ -309,7 +309,8 @@ def _set_default_parameters(self):
 
     def _init_device(self):
         use_gpu = self.final_config_dict['use_gpu']
-        if use_gpu:
+        ###@Juyong Jiang
+        if use_gpu and not self.final_config_dict['multi_gpus']:
             os.environ["CUDA_VISIBLE_DEVICES"] = str(self.final_config_dict['gpu_id'])
         self.final_config_dict['device'] = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
 
diff --git a/recbole/properties/model/BERT4Rec.yaml b/recbole/properties/model/BERT4Rec.yaml
index f8299038a..252ab1cb5 100644
--- a/recbole/properties/model/BERT4Rec.yaml
+++ b/recbole/properties/model/BERT4Rec.yaml
@@ -8,4 +8,8 @@ hidden_act: 'gelu'
 layer_norm_eps: 1e-12
 initializer_range: 0.02
 mask_ratio: 0.2
-loss_type: 'CE'
\ No newline at end of file
+loss_type: 'CE'
+###@Juyong Jiang
+# Note: `training_neg_sample_num` must be 0 when `loss_type` is 'CE'.
+training_neg_sample_num: 0
+multi_gpus: True
\ No newline at end of file
diff --git a/recbole/properties/overall.yaml b/recbole/properties/overall.yaml
index a99b33176..cf5092d3a 100644
--- a/recbole/properties/overall.yaml
+++ b/recbole/properties/overall.yaml
@@ -21,6 +21,8 @@ clip_grad_norm: ~
 # clip_grad_norm:  {'max_norm': 5, 'norm_type': 2}
 weight_decay: 0.0
 draw_loss_pic: False
+###@Juyong Jiang
+multi_gpus: False
 
 # evaluation settings
 eval_setting: RO_RS,full
diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py
index 64544eb0e..f122ced6c 100644
--- a/recbole/trainer/trainer.py
+++ b/recbole/trainer/trainer.py
@@ -33,6 +33,17 @@
     DataLoaderType, KGDataLoaderState
 from recbole.utils.utils import set_color
 
+###
+"""
+    # @Time   : 2021/09/10
+    # @Author : Juyong Jiang
+    # @Email  : csjuyongjiang@gmail.com
+"""
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel
+###
 
 class AbstractTrainer(object):
     r"""Trainer Class is used to manage the training and evaluation processes of recommender system models.
@@ -105,6 +116,13 @@ def __init__(self, config, model):
         self.item_tensor = None
         self.tot_item_num = None
 
+        ###@Juyong Jiang
+        # Turn multi-GPU training on or off via the `multi_gpus` option in your `config.yaml`.
+        self.multi_gpus = config['multi_gpus']
+        if torch.cuda.device_count() > 1 and self.multi_gpus:
+            self._build_distribute(backend="nccl")
+            print("Let's use", torch.cuda.device_count(), "GPUs to train", self.config['model'], "...")
+
     def _build_optimizer(self, params):
         r"""Init the Optimizer
 
@@ -134,6 +152,43 @@ def _build_optimizer(self, params):
             optimizer = optim.Adam(params, lr=self.learning_rate)
         return optimizer
 
+    ###@Juyong Jiang
+    def _build_distribute(self, backend):
+        # 1. Initialize the default process group with the given backend.
+        torch.distributed.init_process_group(backend=backend)
+        # 2. Get this process's rank and bind it to the corresponding GPU.
+        local_rank = torch.distributed.get_rank()
+        torch.cuda.set_device(local_rank)
+        device_dis = torch.device("cuda", local_rank)
+        # 3. Move the model to the local device and wrap it with DistributedDataParallel.
+        self.model.to(device_dis)
+        self.model = DistributedDataParallel(self.model,
+                                             device_ids=[local_rank],
+                                             output_device=local_rank).module
+        return self.model
+
+    def _trans_dataload(self, interaction):
+        data_dict = {}
+        # Re-wrap each field with a PyTorch DataLoader that uses a DistributedSampler.
+        def sub_trans(dataset):
+            dis_loader = DataLoader(dataset=dataset,
+                                    batch_size=dataset.shape[0],
+                                    sampler=DistributedSampler(dataset, shuffle=False))
+            for data in dis_loader:
+                batch_data = data
+
+            return batch_data
+        # Convert the `interaction` object into a plain Python `dict`.
+        # For other models, you may need to transfer additional fields in the same way.
+        data_dict[self.config['USER_ID_FIELD']] = sub_trans(interaction[self.config['USER_ID_FIELD']])
+        data_dict[self.config['ITEM_ID_FIELD']] = sub_trans(interaction[self.config['ITEM_ID_FIELD']])
+        data_dict[self.config['TIME_FIELD']] = sub_trans(interaction[self.config['TIME_FIELD']])
+        data_dict[self.config['ITEM_LIST_LENGTH_FIELD']] = sub_trans(interaction[self.config['ITEM_LIST_LENGTH_FIELD']])
+        data_dict['item_id_list'] = sub_trans(interaction['item_id_list'])
+        data_dict['timestamp_list'] = sub_trans(interaction['timestamp_list'])
+        return data_dict
+    ###
+
     def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
         r"""Train the model in an epoch
 
@@ -161,6 +216,17 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=Fals
             )
         for batch_idx, interaction in iter_data:
             interaction = interaction.to(self.device)
+
+            ###@Juyong Jiang
+            # In practice, converting the interaction costs negligible time.
+            if torch.cuda.device_count() > 1 and self.multi_gpus:
+                # import time
+                # start_ct = time.time()
+                interaction = self._trans_dataload(interaction)
+                # end_ct = time.time()
+                # print('Dataset Converting Time: ', end_ct - start_ct)
+            ###
+
             self.optimizer.zero_grad()
             losses = loss_func(interaction)
             if isinstance(losses, tuple):
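
Usage note: below is a minimal sketch of how this multi-GPU path could be driven, assuming RecBole's standard `recbole.quick_start.run_recbole` entry point is used unchanged; the script name `run_multi_gpu.py`, the launch command, and the `ml-100k` dataset are illustrative assumptions, not part of the patch.

    # run_multi_gpu.py -- hypothetical driver script for the patched Trainer.
    from recbole.quick_start import run_recbole

    if __name__ == '__main__':
        # With `multi_gpus: True` (see BERT4Rec.yaml above), the patched Trainer
        # initializes an NCCL process group and shards each training batch across
        # the visible GPUs. The worker processes themselves still have to be
        # spawned externally, e.g.:
        #   python -m torch.distributed.launch --nproc_per_node=<num_gpus> run_multi_gpu.py
        run_recbole(model='BERT4Rec', dataset='ml-100k')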