This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Merge pull request #145 from csukuangfj/fangjun-ddp
Support multi-GPU training with DDP from PyTorch.
danpovey authored Apr 12, 2021
2 parents 835e59e + beb4f54 commit db507b4
Showing 1 changed file with 69 additions and 14 deletions.
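
For orientation before the diff: the change follows the standard single-node DistributedDataParallel pattern, in which mp.spawn() launches one process per GPU and each process initializes the process group, pins itself to its own device, and wraps the model in DDP. Below is a minimal, self-contained sketch of that pattern; it is not code from the commit. The rendezvous settings, the placeholder model, and the tiny training step are illustrative assumptions, and in the diff itself the process-group setup and teardown are delegated to snowfall.dist.setup_dist and cleanup_dist.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int):
    # mp.spawn() supplies `rank` automatically; it lies in [0, world_size).
    os.environ.setdefault('MASTER_ADDR', 'localhost')   # illustrative rendezvous settings
    os.environ.setdefault('MASTER_PORT', '12355')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    device = torch.device('cuda', rank)
    model = nn.Linear(10, 10).to(device)                # placeholder for the acoustic model
    model = DDP(model, device_ids=[rank])               # gradients are all-reduced across ranks

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(8, 10, device=device)
    loss = model(x).sum()
    loss.backward()                                     # DDP synchronizes gradients here
    optimizer.step()

    dist.barrier()                                      # wait for all ranks before tearing down
    dist.destroy_process_group()


def main():
    world_size = torch.cuda.device_count()              # the script below exposes this as --world-size
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)


if __name__ == '__main__':
    main()
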
83 changes: 69 additions & 14 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -14,9 +14,11 @@
import os
import sys
import torch
import torch.multiprocessing as mp
from datetime import datetime
from pathlib import Path
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Optional, Tuple
@@ -31,6 +33,8 @@
from snowfall.common import load_checkpoint, save_checkpoint
from snowfall.common import save_training_info
from snowfall.common import setup_logger
from snowfall.dist import cleanup_dist
from snowfall.dist import setup_dist
from snowfall.models import AcousticModel
from snowfall.models.conformer import Conformer
from snowfall.models.transformer import Noam, Transformer
@@ -275,7 +279,7 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader,
time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

if forward_count == 1 or accum_grad == 1:
P.set_scores_stochastic_(model.P_scores)
P.set_scores_stochastic_(model.module.P_scores)
assert P.requires_grad is True

curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
@@ -342,13 +346,18 @@ def train_one_epoch(dataloader: torch.utils.data.DataLoader,
tb_writer.add_scalar('train/global_valid_average_objf',
valid_average_objf,
global_batch_idx_train)
model.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train)
model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train)
prev_timestamp = datetime.now()
return total_objf / total_frames, valid_average_objf, global_batch_idx_train


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
'--world-size',
type=int,
default=1,
help='Number of GPUs for DDP training.')
parser.add_argument(
'--model-type',
type=str,
@@ -359,7 +368,7 @@ def get_parser():
'--num-epochs',
type=int,
default=10,
help="Number of traning epochs.")
help="Number of training epochs.")
parser.add_argument(
'--start-epoch',
type=int,
@@ -452,9 +461,18 @@ def get_parser():
return parser


def main():
args = get_parser().parse_args()

def run(rank, world_size, args):
'''
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoints.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
'''
model_type = args.model_type
start_epoch = args.start_epoch
num_epochs = args.num_epochs
@@ -464,10 +482,15 @@ def main():
att_rate = args.att_rate

fix_random_seed(42)
setup_dist(rank, world_size)

exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
setup_logger('{}/log/log-train'.format(exp_dir))
tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.tensorboard else None
setup_logger(f'{exp_dir}/log/log-train-{rank}')
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
else:
tb_writer = None
# tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.tensorboard and rank == 0 else None

# load L, G, symbol_table
lang_dir = Path('data/lang_nosp')
@@ -481,9 +504,10 @@ def main():
with open(lang_dir / 'L.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
L_inv = k2.arc_sort(L.invert_())
torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')
if rank == 0:
torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt')

device_id = 0
device_id = rank
device = torch.device('cuda', device_id)

graph_compiler = MmiTrainingGraphCompiler(
@@ -626,6 +650,9 @@ def main():
model.to(device)
describe(model)

model = DDP(model, device_ids=[rank])


optimizer = Noam(model.parameters(),
model_size=args.attention_dim,
factor=1.0,
@@ -683,7 +710,8 @@ def main():
learning_rate=curr_learning_rate,
objf=objf,
valid_objf=valid_objf,
global_batch_idx_train=global_batch_idx_train)
global_batch_idx_train=global_batch_idx_train,
local_rank=rank)
save_training_info(filename=best_epoch_info_filename,
model_path=best_model_path,
current_epoch=epoch,
@@ -692,7 +720,8 @@ def main():
best_objf=best_objf,
valid_objf=valid_objf,
best_valid_objf=best_valid_objf,
best_epoch=best_epoch)
best_epoch=best_epoch,
local_rank=rank)

# we always save the model for every epoch
model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch))
@@ -704,7 +733,8 @@ def main():
learning_rate=curr_learning_rate,
objf=objf,
valid_objf=valid_objf,
global_batch_idx_train=global_batch_idx_train)
global_batch_idx_train=global_batch_idx_train,
local_rank=rank)
epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch))
save_training_info(filename=epoch_info_filename,
model_path=model_path,
@@ -714,9 +744,34 @@ def main():
best_objf=best_objf,
valid_objf=valid_objf,
best_valid_objf=best_valid_objf,
best_epoch=best_epoch)
best_epoch=best_epoch,
local_rank=rank)

logging.warning('Done')
torch.distributed.barrier()
# NOTE: The training process is very likely to hang at this point.
# If you press ctrl + c, your GPU memory will not be freed.
# To free your GPU memory, you can run:
#
# $ ps aux | grep multi
#
# And it will print something like below:
#
# kuangfa+ 430518 98.9 0.6 57074236 3425732 pts/21 Rl Apr02 639:01 /root/fangjun/py38/bin/python3 -c from multiprocessing.spawn
#
# You can kill the process manually by:
#
# $ kill -9 430518
#
# And you will see that your GPU is now not occupied anymore.
cleanup_dist()


def main():
args = get_parser().parse_args()
world_size = args.world_size
assert world_size >= 1
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)


torch.set_num_threads(1)
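
One recurring edit in the diff is replacing model.P_scores and model.write_tensorboard_diagnostics(...) with model.module.P_scores and model.module.write_tensorboard_diagnostics(...). Once the acoustic model is wrapped in DistributedDataParallel, custom attributes and methods defined on it are only reachable through the wrapper's .module attribute. A small self-contained illustration follows; it uses CPU and the gloo backend purely for the demo, and P_scores here is a stand-in attribute, not snowfall's actual parameter.

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '12356')
dist.init_process_group('gloo', rank=0, world_size=1)   # single process, just to build a DDP wrapper


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.P_scores = nn.Parameter(torch.zeros(4))     # stand-in for a model-specific attribute


model = DDP(Toy())
print(hasattr(model, 'P_scores'))      # False: the DDP wrapper does not expose custom attributes
print(model.module.P_scores.shape)     # torch.Size([4]): go through .module instead

dist.destroy_process_group()
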
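
The diff also threads local_rank=rank into save_checkpoint and save_training_info. Presumably those snowfall helpers use it so that only rank 0 writes files and concurrent ranks do not clobber the same checkpoint path. The sketch below shows that guard; it is an assumption about how the helpers behave, not their actual code, and the function name and signature are illustrative only.

import torch


def save_checkpoint(filename: str, model: torch.nn.Module, local_rank: int = 0) -> None:
    # Assumed behaviour: only rank 0 touches the filesystem.
    if local_rank != 0:
        return
    # Unwrap DDP (if present) so the saved state_dict carries the plain module's key names.
    inner = model.module if hasattr(model, 'module') else model
    torch.save({'state_dict': inner.state_dict()}, filename)
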
