pytorch_utils.py (forked from mlcommons/algorithmic-efficiency)
import os
from typing import Tuple

from absl import logging
import jax
import tensorflow as tf
import torch
import torch.distributed as dist

from algoperf import spec
from algoperf.profiler import Profiler
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
    BatchNorm as ConformerBatchNorm
from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
    BatchNorm as DeepspeechBatchNorm


def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
  use_pytorch_ddp = 'LOCAL_RANK' in os.environ
  rank = int(os.environ['LOCAL_RANK']) if use_pytorch_ddp else 0
  device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
  n_gpus = torch.cuda.device_count()
  return use_pytorch_ddp, rank, device, n_gpus
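
# Example (sketch, not from the original file): under a launcher such as
# `torchrun --nproc_per_node=4 train.py` (script name illustrative), each
# process gets its own LOCAL_RANK in 0..3, so pytorch_setup() returns
# (True, <local_rank>, torch.device('cuda:<local_rank>'), 4) on a 4-GPU host;
# without LOCAL_RANK set, it falls back to single-process defaults
# (False, 0, <device>, <n_gpus>).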


def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
  # Make sure no GPU memory is preallocated to Jax.
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
  # Only use CPU for Jax to avoid memory issues.
  jax.config.update('jax_platforms', 'cpu')
  jax.config.update('jax_platform_name', 'cpu')
  # From the docs: "(...) causes cuDNN to benchmark multiple convolution
  # algorithms and select the fastest."
  torch.backends.cudnn.benchmark = True

  if use_pytorch_ddp:
    # Avoid tf input pipeline creating too many threads.
    if rank != 0:
      tf.config.threading.set_intra_op_parallelism_threads(1)
      tf.config.threading.set_inter_op_parallelism_threads(1)

    torch.cuda.set_device(rank)
    profiler.set_local_rank(rank)
    # Only log once (for local rank == 0).
    if rank != 0:

      def logging_pass(*args):
        pass

      logging.info = logging_pass
    # Initialize the process group.
    dist.init_process_group('nccl')
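
# Example (sketch, not from the original file): the two helpers above are
# typically called once at program startup, roughly as follows. The no-argument
# Profiler() construction is an assumption, not taken from this repository.
#
#   use_pytorch_ddp, rank, device, n_gpus = pytorch_setup()
#   profiler = Profiler()  # Assumed constructor; adapt to the real signature.
#   pytorch_init(use_pytorch_ddp, rank, profiler)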


def sync_ddp_time(time: float, device: torch.device) -> float:
  """Returns the maximum of `time` across all DDP processes."""
  time_tensor = torch.tensor(time, dtype=torch.float64, device=device)
  dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX)
  return time_tensor.item()
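
# Example (sketch, not from the original file): agreeing on a single wall-clock
# reading across DDP ranks by taking the maximum, e.g. when checking whether a
# time budget has been exhausted.
#
#   import time
#   now = sync_ddp_time(time.time(), device)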


def update_batch_norm_fn(module: spec.ParameterContainer,
                         update_batch_norm: bool) -> None:
  """Freezes or restores BatchNorm running-statistic updates.

  With `update_batch_norm=False`, the layer's momentum is backed up and set to
  zero so the running mean/variance are no longer updated; otherwise the
  backed-up momentum is restored.
  """
  bn_layers = (
      torch.nn.modules.batchnorm._BatchNorm,  # PyTorch BN base class.
      ConformerBatchNorm,  # Custom BN class for conformer model.
      DeepspeechBatchNorm,  # Custom BN class for deepspeech model.
  )
  if isinstance(module, bn_layers):
    if not update_batch_norm:
      if not hasattr(module, 'momentum_backup'):
        module.momentum_backup = module.momentum

      # module.momentum can be float or torch.Tensor.
      if torch.is_tensor(module.momentum_backup):
        module.momentum = torch.zeros_like(module.momentum_backup)
      else:
        module.momentum = 0.0
    elif hasattr(module, 'momentum_backup'):
      module.momentum = module.momentum_backup
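
# Example (sketch, not from the original file): this function is shaped for use
# with torch.nn.Module.apply, e.g. to freeze BatchNorm running statistics for a
# forward pass. The `model` variable is illustrative.
#
#   import functools
#   model.apply(
#       functools.partial(update_batch_norm_fn, update_batch_norm=False))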