Skip to content

Commit

Permalink
fix nccl support
Browse files Browse the repository at this point in the history
  • Loading branch information
yzy-thu committed May 27, 2022
1 parent 7a8420f commit 8a46903
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions SwissArmyTransformer/training/deepspeed_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,11 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,

loss_checker = lm_loss_reduced
for name in metrics:
metrics[name] = metrics[name].detach().clone()
torch.distributed.all_reduce(metrics[name].data)
metrics[name].data /= args.world_size
loss_checker = loss_checker + metrics[name]
if not 'eval' in name:
metrics[name] = metrics[name].detach().clone()
torch.distributed.all_reduce(metrics[name].data)
metrics[name].data /= args.world_size
loss_checker = loss_checker + metrics[name]
if loss_checker.isnan().any() or loss_checker.isinf().any():
print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
return lm_loss.detach(), 1, metrics
Expand Down Expand Up @@ -460,25 +461,28 @@ def evaluate(data_iterator, model, eval_iters, args, timers, split, verbose=Fals
if args.deepspeed and args.deepspeed_activation_checkpointing:
deepspeed.checkpointing.reset()
total_lm_loss += lm_loss.data.detach().float().item()
is_last = True if iteration == eval_iters and args.strict_eval else False
is_last = True if iteration == eval_iters and args.strict_eval and len(last_shape)>0 else False
for name in metrics:
if name not in metrics_total:
metrics_total[name] = []
is_scalar[name] = True if len(metrics[name].shape)==0 else False
shape = list(metrics[name].shape)
if rank==0:
metrics_gathered = [torch.zeros_like(metrics[name], dtype=metrics[name].dtype, device=metrics[name].device) for _ in range(args.world_size)]
else:
metrics_gathered = None
if not is_scalar[name] and is_last and metrics[name].shape[0] != last_shape[0]:
# pad tensor's first dim to args.batch_size
metrics[name] = torch.concat([metrics[name], torch.zeros([last_shape[0]-metrics[name].shape[0]] + shape[1:], dtype=metrics[name].dtype, device=metrics[name].device)])
torch.distributed.gather(metrics[name], metrics_gathered, 0)
if rank==0:
metrics_gathered = [torch.zeros_like(metrics[name], dtype=metrics[name].dtype, device=metrics[name].device) for _ in range(args.world_size)]
else:
# metrics_gathered = None
metrics_gathered = [torch.zeros_like(metrics[name], dtype=metrics[name].dtype, device=metrics[name].device) for _ in range(args.world_size)]
# torch.distributed.gather(metrics[name], metrics_gathered, 0)
torch.distributed.all_gather(metrics_gathered, metrics[name])

if rank==0:
gathered_len = len(metrics_gathered) if not is_last else len(metrics_gathered) - drop_number
for i in range(gathered_len):
if is_scalar[name] or not is_last:
metrics_total[name].append(metrics_gathered[i].data.cpu())
metrics_total[name].append(metrics_gathered[i].data.cpu())
else:
metrics_total[name].append(metrics_gathered[i][:last_shape[i]].data.cpu())
# Move model back to the train mode.
Expand Down

0 comments on commit 8a46903

Please sign in to comment.