Implementing Multi-GPU Training for RecBole #961

Open · wants to merge 1 commit into base: master
3 changes: 2 additions & 1 deletion recbole/config/configurator.py
@@ -309,7 +309,8 @@ def _set_default_parameters(self):

     def _init_device(self):
         use_gpu = self.final_config_dict['use_gpu']
-        if use_gpu:
+        ###@Juyong Jiang
+        if use_gpu and not self.final_config_dict['multi_gpus']:
             os.environ["CUDA_VISIBLE_DEVICES"] = str(self.final_config_dict['gpu_id'])
         self.final_config_dict['device'] = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")

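The gate added above can be sketched in isolation. This is a simplified illustration, not RecBole's actual `_init_device`: the config is a bare dict and the device is returned as a plain string. The point is that when `multi_gpus` is on, `CUDA_VISIBLE_DEVICES` is left untouched so that a distributed launcher can expose every GPU to its workers.

```python
import os

def init_device(config, cuda_available=True):
    """Sketch of the patched _init_device logic above."""
    use_gpu = config['use_gpu']
    # pin the process to one GPU only in single-GPU mode;
    # in multi-GPU mode the launcher decides the device mapping
    if use_gpu and not config.get('multi_gpus', False):
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config['gpu_id'])
    return 'cuda' if cuda_available and use_gpu else 'cpu'

print(init_device({'use_gpu': True, 'gpu_id': 2, 'multi_gpus': False}))  # → cuda
```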
6 changes: 5 additions & 1 deletion recbole/properties/model/BERT4Rec.yaml
@@ -8,4 +8,8 @@ hidden_act: 'gelu'
 layer_norm_eps: 1e-12
 initializer_range: 0.02
 mask_ratio: 0.2
-loss_type: 'CE'
+loss_type: 'CE'
+###@Juyong Jiang
+# Note: `training_neg_sample_num` should be 0 when `loss_type` is 'CE'
+training_neg_sample_num: 0
+multi_gpus: True
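The `training_neg_sample_num` note can be expressed as a small validation helper. This is a hypothetical check, not part of RecBole: with full-softmax cross-entropy (`loss_type: 'CE'`) every item is scored by the model, so explicitly sampled negatives must be disabled.

```python
def check_neg_sampling(config):
    """Hypothetical guard for the CE / negative-sampling constraint above."""
    if config.get('loss_type') == 'CE' and config.get('training_neg_sample_num', 0) != 0:
        raise ValueError(
            "`training_neg_sample_num` must be 0 when `loss_type` is 'CE'")

check_neg_sampling({'loss_type': 'CE', 'training_neg_sample_num': 0})  # ok
```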
2 changes: 2 additions & 0 deletions recbole/properties/overall.yaml
@@ -21,6 +21,8 @@ clip_grad_norm: ~
 # clip_grad_norm: {'max_norm': 5, 'norm_type': 2}
 weight_decay: 0.0
 draw_loss_pic: False
+###@Juyong Jiang
+multi_gpus: False
 
 # evaluation settings
 eval_setting: RO_RS,full
66 changes: 66 additions & 0 deletions recbole/trainer/trainer.py
@@ -33,6 +33,17 @@
     DataLoaderType, KGDataLoaderState
 from recbole.utils.utils import set_color
 
+###
+"""
+# @Time   : 2021/09/10
+# @Author : Juyong Jiang
+# @Email  : [email protected]
+"""
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel
+###
 
 class AbstractTrainer(object):
     r"""Trainer Class is used to manage the training and evaluation processes of recommender system models.
@@ -105,6 +116,13 @@ def __init__(self, config, model):
         self.item_tensor = None
         self.tot_item_num = None
 
+        ###@Juyong Jiang
+        # you can turn this setting on or off via `multi_gpus` in your `config.yaml`
+        self.multi_gpus = config['multi_gpus']
+        if torch.cuda.device_count() > 1 and self.multi_gpus:
+            self._build_distribute(backend="nccl")
+            print("Let's use", torch.cuda.device_count(), "GPUs to train", self.config['model'], "...")
+
     def _build_optimizer(self, params):
         r"""Init the Optimizer
 
@@ -134,6 +152,43 @@ def _build_optimizer(self, params):
         optimizer = optim.Adam(params, lr=self.learning_rate)
         return optimizer
 
+    ###@Juyong Jiang
+    def _build_distribute(self, backend):
+        # 1. initialize the default process group
+        torch.distributed.init_process_group(backend=backend)
+        # 2. get this process's rank and bind it to one GPU
+        local_rank = torch.distributed.get_rank()
+        torch.cuda.set_device(local_rank)
+        device_dis = torch.device("cuda", local_rank)
+        # 3./4. move the model to its device and wrap it with DDP
+        # (note: `.module` immediately unwraps the DDP container again,
+        #  keeping only the underlying model)
+        self.model.to(device_dis)
+        self.model = DistributedDataParallel(self.model,
+                                             device_ids=[local_rank],
+                                             output_device=local_rank).module
+        return self.model

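The wrap-and-unwrap pattern in `_build_distribute` can be exercised without multiple GPUs. The sketch below uses a one-process group over the `gloo` backend on CPU; in the real patch, `nccl` reads its rank and world size from the environment prepared by a launcher such as `python -m torch.distributed.launch --nproc_per_node=N ...`, so it cannot run standalone. The address/port here are arbitrary example values.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# one-process "distributed" group on CPU, just enough to build a DDP wrapper
torch.distributed.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500",
    rank=0, world_size=1)

model = nn.Linear(4, 2)
ddp_model = DistributedDataParallel(model)

# DDP is a wrapper; `.module` hands back the original model object,
# which is what the patch above stores into self.model
assert ddp_model.module is model

torch.distributed.destroy_process_group()
```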
+    def _trans_dataload(self, interaction):
+        data_dict = {}
+
+        # use a PyTorch DataLoader to re-wrap the dataset
+        def sub_trans(dataset):
+            dis_loader = DataLoader(dataset=dataset,
+                                    batch_size=dataset.shape[0],
+                                    sampler=DistributedSampler(dataset, shuffle=False))
+            for data in dis_loader:
+                batch_data = data
Member:
I do not understand this loop, it seems batch_data will be the last 'data' of dis_loader, could you please explain it?

Member:

And did you test your code in some datasets like ml-100k? Could you provide us the performance results of models? I want to know if the model performance will change a lot compared with single-GPU training.

Contributor Author (@juyongjiang, Sep 10, 2021):

Hi, Xingyu! Yeah, my pleasure! In our `DataLoader` call, I set `batch_size=dataset.shape[0]`, which means the loader pulls the whole dataset in a single batch. So the length of `dis_loader` is exactly one, i.e. the loop behaves like `for data in range(1)`.

https://github.com/juyongjiang/RecBole/blob/0d35771629f65a9a06ad7e66dd11bfbe06091971/recbole/trainer/trainer.py#L173-L180
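The single-batch behaviour described above can be checked in isolation: with `batch_size` equal to the dataset length, the loader yields exactly one batch, so the `for` loop body runs once. Passing `num_replicas=1, rank=0` explicitly lets `DistributedSampler` run without an initialized process group (a CPU-only sketch).

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

data = torch.arange(10)
loader = DataLoader(dataset=data,
                    batch_size=data.shape[0],
                    sampler=DistributedSampler(data, num_replicas=1, rank=0,
                                               shuffle=False))
batches = list(loader)
print(len(batches))  # → 1: the whole dataset arrives as a single batch
```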

Contributor Author:

Yeah, of course! Please wait for a moment! I will provide a table to illustrate the performance compared with single GPU training.

Contributor Author (@juyongjiang, Sep 10, 2021):

@2017pxy Hi, Xingyu! I have got the experimental results. Multi-GPU training barely changes the metrics but cuts training time by about 3.78×. BTW, I only ran the experiment once, so I think this small performance drift can be ignored. : )
Note that the "original" row is the result of running your unmodified RecBole code, and the "multi-GPUs" row was produced with 3 GPUs.
[image: results table comparing the single-GPU and 3-GPU runs]

Contributor:

@2017pxy Any further questions or comments? Thanks in advance.

Member (@2017pxy, Nov 17, 2021):
Hi @juyongjiang @hunkim, sorry for the late reply.

Following your implementation, our team modified the trainer and ran some tests. We found that your implementation works well for model training. Thanks for your contribution!

However, since the time cost of run_recbole mainly comes from model evaluation, we want to implement multi-GPU evaluation as well and release it together with multi-GPU training. Unfortunately, we ran into some problems when applying your implementation to evaluation, because the data organization for evaluation is different. So I am sorry to say that this feature will take some more time to release, and it might not make it into the next version.

Thanks again for your implementation, and if you have any ideas or suggestions about multi-GPU evaluation, please let us know.

Contributor Author:
@2017pxy Okay, got it! Thanks for your reply. I will implement multi-GPU evaluation as well and open a new pull request. : )


+                return batch_data
+
+        # change `interaction` to a plain python `dict`;
+        # for some models you may need to transfer more fields in the same way
+        data_dict[self.config['USER_ID_FIELD']] = sub_trans(interaction[self.config['USER_ID_FIELD']])
+        data_dict[self.config['ITEM_ID_FIELD']] = sub_trans(interaction[self.config['ITEM_ID_FIELD']])
+        data_dict[self.config['TIME_FIELD']] = sub_trans(interaction[self.config['TIME_FIELD']])
+        data_dict[self.config['ITEM_LIST_LENGTH_FIELD']] = sub_trans(interaction[self.config['ITEM_LIST_LENGTH_FIELD']])
+        data_dict['item_id_list'] = sub_trans(interaction['item_id_list'])
+        data_dict['timestamp_list'] = sub_trans(interaction['timestamp_list'])
+        return data_dict
+    ###

     def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
         r"""Train the model in an epoch
 
@@ -161,6 +216,17 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
             )
             for batch_idx, interaction in iter_data:
                 interaction = interaction.to(self.device)
+
+                ###@Juyong Jiang
+                # in fact, converting the dataset costs negligible time
+                if torch.cuda.device_count() > 1 and self.multi_gpus:
+                    # import time
+                    # start_ct = time.time()
+                    interaction = self._trans_dataload(interaction)
+                    # end_ct = time.time()
+                    # print('Dataset Converting Time: ', end_ct - start_ct)
+                ###
+
                 self.optimizer.zero_grad()
                 losses = loss_func(interaction)
                 if isinstance(losses, tuple):
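The commented-out timing hook above can be folded into a tiny reusable helper (a sketch, not part of the patch; `timed` is a hypothetical name):

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn and report its wall-clock duration, mirroring the
    commented-out start_ct/end_ct hook above."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    print(f'Dataset Converting Time: {time.perf_counter() - start:.4f}s')
    return result
```

Inside the training loop this would read, e.g., `interaction = timed(self._trans_dataload, interaction)`.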