-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·146 lines (126 loc) · 5.3 KB
/
utils.py
File metadata and controls
executable file
·146 lines (126 loc) · 5.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from data_preprocess_and_load.datasets import * #####
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from datetime import datetime
from pytz import timezone
import argparse
import os
import dill
import random
import builtins
import time
import random
def _get_sync_file():
"""Logic for naming sync file using slurm env variables"""
if 'SCRATCH' in os.environ:
sync_file_dir = '%s/pytorch-sync-files' % os.environ['SCRATCH'] # Perlmutter
else:
sync_file_dir = '%s/pytorch-sync-files' % '/lus/grand/projects/STlearn/'
#raise Exception('there is no env variable SCRATCH. Please check sync_file dir')
os.makedirs(sync_file_dir, exist_ok=True)
#temporally add two lines below for torchrun
if ('SLURM_JOB_ID' in os.environ) and ('SLURM_STEP_ID' in os.environ) :
sync_file = 'file://%s/pytorch_sync.%s.%s' % (
sync_file_dir, os.environ['SLURM_JOB_ID'], os.environ['SLURM_STEP_ID'])
else:
sync_file = 'file://%s/pytorch_sync.%s.%s' % (
sync_file_dir, '10004', '10003')
return sync_file
def init_distributed(args):
    """Initialize torch.distributed from torchrun / SLURM / launch env vars.

    Mutates ``args`` in place, setting ``world_size``, ``distributed``,
    ``rank`` and ``gpu``.  Supports three launchers, probed in this order:
    torchrun (WORLD_SIZE/RANK/LOCAL_RANK), the SLURM scheduler
    (SLURM_NTASKS/SLURM_PROCID), and torch.distributed.launch
    (``args.local_rank``).  When only one process is detected, no process
    group is created and rank/gpu are set to 0.

    Args:
        args: argparse namespace; must provide ``world_size`` (default -1),
            ``local_rank``, ``init_method`` ('file' or 'env') and
            ``dist_backend``.
    """
    # torchrun: when WORLD_SIZE is set in the sbatch script (gpus_per_node * num_nodes)
    if "WORLD_SIZE" in os.environ:  # for torchrun
        args.world_size = int(os.environ["WORLD_SIZE"])
        #print('args.world_size:',args.world_size)
    elif 'SLURM_NTASKS' in os.environ:  # for slurm scheduler
        args.world_size = int(os.environ['SLURM_NTASKS'])
    else:
        pass  # torch.distributed.launch
    args.distributed = args.world_size > 1  # default: world_size = -1
    if args.distributed:
        start_time = time.time()
        #args.local_rank = int(os.environ['LOCAL_RANK']) #stella added this line
        # Rank/GPU resolution — order matters: an explicit local_rank from
        # torch.distributed.launch wins over torchrun env vars, which win
        # over SLURM.
        if args.local_rank != -1:  # for torch.distributed.launch
            args.rank = args.local_rank
            args.gpu = args.local_rank
        elif 'RANK' in os.environ:  # for torchrun
            args.rank = int(os.environ['RANK'])
            args.gpu = int(os.environ['LOCAL_RANK'])
        elif 'SLURM_PROCID' in os.environ:  # for slurm scheduler
            args.rank = int(os.environ['SLURM_PROCID'])
            # Map the global rank onto a local device index.
            args.gpu = args.rank % torch.cuda.device_count()
        #print('args.rank:',args.rank)
        #print('args.gpu:',args.gpu)
        # Rendezvous: shared filesystem file vs MASTER_ADDR/MASTER_PORT env.
        if args.init_method == 'file':
            sync_file = _get_sync_file()
            # print('initializing DDP with sync file')
        elif args.init_method == 'env':
            sync_file = "env://"
            #os.environ['MASTER_PORT'] = '47769'
            #os.environ['MASTER_ADDR'] = '127.0.0.1'# os.environ['SLURM_JOB_NODELIST']
            #print(os.environ['MASTER_PORT'])
            #print(os.environ['MASTER_ADDR'])
            #print('initializing DDP with env variables')
        dist.init_process_group(backend=args.dist_backend, init_method=sync_file,
                                world_size=args.world_size, rank=args.rank)
        #dist_init_time = time.time() - start_time
        #print(f'seconds taken for DDP initialization: {dist_init_time}')
    else:
        # Single-process run: no process group needed.
        args.rank = 0
        args.gpu = 0
def weight_loader(args):
    """Resolve the pretrained-weight path and task name for a training step.

    Step mapping: '1' -> vanilla_BERT (ABCD phase-1 weights), '2' -> MBBN
    (ABCD weights), '3' -> MBBN_reconstruction (no pretrained weights),
    '4' -> test/visualization phase (phase-2 weights, task is None).

    Args:
        args: argparse namespace with ``step`` and, optionally,
            ``model_weights_path_ABCD`` / ``model_weights_path_phase2``.

    Returns:
        tuple: (model_weights_path or None, args.step, task or None).
    """
    model_weights_path = None
    # Initialize task so an unrecognized step returns (None, step, None)
    # instead of raising NameError at the return statement.
    task = None
    try:
        if args.step == '1':
            task = 'vanilla_BERT'
            if os.path.exists(args.model_weights_path_ABCD):
                model_weights_path = args.model_weights_path_ABCD
        elif args.step == '2':
            task = 'MBBN'
            if os.path.exists(args.model_weights_path_ABCD):
                model_weights_path = args.model_weights_path_ABCD
        elif args.step == '3':
            task = 'MBBN_reconstruction'
        elif args.step == '4':
            task = None  # test phase (for visualization)
            if os.path.exists(args.model_weights_path_phase2):
                model_weights_path = args.model_weights_path_phase2
    except (AttributeError, TypeError):
        # No weights were provided: the path attribute is missing from args
        # (AttributeError) or is None (TypeError from os.path.exists).
        model_weights_path = None
    # print(f'loading weight from {model_weights_path}')
    return model_weights_path, args.step, task
def datestamp():
    """Return the current Seoul-local time formatted as ``MM_DD__HH_MM_SS``.

    Returns:
        str: timestamp suitable for experiment-folder names.
    """
    # zoneinfo (stdlib since 3.9) replaces the third-party pytz dependency;
    # the function-scope import keeps the change self-contained.  Also avoids
    # the original local variable named `time`, which shadowed the `time`
    # module imported at the top of this file.
    from zoneinfo import ZoneInfo
    return datetime.now(ZoneInfo('Asia/Seoul')).strftime("%m_%d__%H_%M_%S")
def reproducibility(**kwargs):
    """Seed the torch, Python-random and NumPy RNGs for repeatable runs.

    Keyword Args:
        seed: integer seed applied to every RNG.
        cuda: when truthy, also seeds the current CUDA device.

    Note: cudnn.deterministic is intentionally left False (benchmark mode on),
    trading exact determinism for speed.
    """
    seed_value = kwargs.get('seed')
    use_cuda = kwargs.get('cuda')
    torch.manual_seed(seed_value)
    # Python's built-in random module must be seeded separately from torch.
    random.seed(seed_value)
    if use_cuda:
        torch.cuda.manual_seed(seed_value)
    # assumes numpy seeding applies regardless of cuda — indentation was lost
    # in this copy; TODO confirm against the original file.
    np.random.seed(seed_value)
    cudnn.deterministic = False  # True would force fully deterministic kernels
    cudnn.benchmark = True
def sort_args(phase, args):
    """Select the arguments relevant to one training phase.

    Keys without 'phase' in their name are kept as-is; keys containing
    'phase<N>' for the requested phase are kept with the '_phase<N>' suffix
    stripped (overriding any generic key of the same base name). Keys for
    other phases are dropped.

    Args:
        phase: phase number as a string, e.g. '1'.
        args: dict of argument name -> value.

    Returns:
        dict: phase-specific argument mapping.
    """
    phase_tag = 'phase' + phase
    selected = {}
    for key, value in args.items():
        if 'phase' not in key:
            selected[key] = value
        elif phase_tag in key:
            selected[key.replace('_' + phase_tag, '')] = value
    return selected
def args_logger(args):
    """Persist the run arguments twice: as a pickle and as readable text.

    Args:
        args: argparse namespace with an ``experiment_folder`` attribute.
    """
    args_to_pkl(args)
    args_to_text(args)
def args_to_pkl(args):
    """Dump the argument namespace (as a dict) to ``arguments_as_is.pkl``.

    Args:
        args: argparse namespace with an ``experiment_folder`` attribute.
    """
    target = os.path.join(args.experiment_folder, 'arguments_as_is.pkl')
    with open(target, 'wb') as handle:
        # dill (not pickle) so lambdas/closures stored on args survive.
        dill.dump(vars(args), handle)
def args_to_text(args):
    """Write every argument as 'name: value' lines to a documentation file.

    Args:
        args: argparse namespace with an ``experiment_folder`` attribute.
    """
    target = os.path.join(args.experiment_folder, 'argument_documentation.txt')
    with open(target, 'w+') as handle:
        for key, value in vars(args).items():
            handle.write('{}: {}\n'.format(key, value))