
Commit 0bbc160

committed
keep only last ckpt + black + fixes multi gpu tqdm
1 parent dc4cde3 commit 0bbc160

2 files changed: +76 −162 lines


train/utils.py

+33 −61
@@ -31,10 +31,7 @@ def go(model, bkey):
             try:
                 new_state_dict[k] = saved_state_dict[k]
                 if saved_state_dict[k].shape != state_dict[k].shape:
-                    print(
-                        "shape-%s-mismatch|need-%s|get-%s"
-                        % (k, state_dict[k].shape, saved_state_dict[k].shape)
-                    )  #
+                    print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape))  #
                     raise KeyError
             except:
                 # logger.info(traceback.format_exc())
@@ -52,9 +49,7 @@ def go(model, bkey):

     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ### if it cannot be loaded (e.g. it is empty), reinitialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
+    if optimizer is not None and load_opt == 1:  ### if it cannot be loaded (e.g. it is empty), reinitialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
         # except:
@@ -106,10 +101,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
         try:
             new_state_dict[k] = saved_state_dict[k]
             if saved_state_dict[k].shape != state_dict[k].shape:
-                print(
-                    "shape-%s-mismatch|need-%s|get-%s"
-                    % (k, state_dict[k].shape, saved_state_dict[k].shape)
-                )  #
+                print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape))  #
                 raise KeyError
         except:
             # logger.info(traceback.format_exc())
@@ -123,9 +115,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):

     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ### if it cannot be loaded (e.g. it is empty), reinitialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
+    if optimizer is not None and load_opt == 1:  ### if it cannot be loaded (e.g. it is empty), reinitialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
         # except:
@@ -134,33 +124,39 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
     return model, optimizer, learning_rate, iteration


-def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, checkpoint_type, delete_old=False):
+    # logger.info(
+    #     "Saving model and optimizer state at epoch {} to {}".format(
+    #         iteration, checkpoint_path
+    #     )
+    # )
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
+    if delete_old:
+        latest_checkpoint = latest_checkpoint_path(checkpoint_path, regex=("G_*.pth" if checkpoint_type.startswith("G") else "D_*.pth"))
+
     torch.save(
         {
             "model": state_dict,
             "iteration": iteration,
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
-        checkpoint_path,
+        os.path.join(checkpoint_path, checkpoint_type),
     )
+    # delete after saving new checkpoint to avoid loss if save fails
+    if delete_old and latest_checkpoint is not None:
+        os.remove(latest_checkpoint)


 def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+    # logger.info(
+    #     "Saving model and optimizer state at epoch {} to {}".format(
+    #         iteration, checkpoint_path
+    #     )
+
     if hasattr(combd, "module"):
         state_dict_combd = combd.module.state_dict()
     else:
@@ -204,7 +200,7 @@ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
     f_list = glob.glob(os.path.join(dir_path, regex))
     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
     x = f_list[-1]
-    print(x)
+    # print(x)
     return x

@@ -247,9 +243,7 @@ def plot_alignment_to_numpy(alignment, info=None):
     import numpy as np

     fig, ax = plt.subplots(figsize=(6, 4))
-    im = ax.imshow(
-        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-    )
+    im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none")
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
     if info is not None:
@@ -302,35 +296,21 @@ def get_hparams(init=True):
         required=True,
         help="checkpoint save frequency (epoch)",
     )
-    parser.add_argument(
-        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
-    )
-    parser.add_argument(
-        "-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path"
-    )
-    parser.add_argument(
-        "-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path"
-    )
+    parser.add_argument("-te", "--total_epoch", type=int, required=True, help="total_epoch")
+    parser.add_argument("-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path")
+    parser.add_argument("-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path")
     parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
-    parser.add_argument(
-        "-bs", "--batch_size", type=int, required=True, help="batch size"
-    )
-    parser.add_argument(
-        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
-    )  # -m
-    parser.add_argument(
-        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
-    )
+    parser.add_argument("-bs", "--batch_size", type=int, required=True, help="batch size")
+    parser.add_argument("-e", "--experiment_dir", type=str, required=True, help="experiment dir")  # -m
+    parser.add_argument("-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k")
     parser.add_argument(
         "-sw",
         "--save_every_weights",
         type=str,
         default="0",
         help="save the extracted model in weights directory when saving checkpoints",
     )
-    parser.add_argument(
-        "-v", "--version", type=str, required=True, help="model version"
-    )
+    parser.add_argument("-v", "--version", type=str, required=True, help="model version")
     parser.add_argument(
         "-f0",
         "--if_f0",
@@ -414,11 +394,7 @@ def get_hparams_from_file(config_path):
 def check_git_hash(model_dir):
     source_dir = os.path.dirname(os.path.realpath(__file__))
     if not os.path.exists(os.path.join(source_dir, ".git")):
-        logger.warn(
-            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                source_dir
-            )
-        )
+        logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(source_dir))
         return

     cur_hash = subprocess.getoutput("git rev-parse HEAD")
@@ -427,11 +403,7 @@ def check_git_hash(model_dir):
     if os.path.exists(path):
         saved_hash = open(path).read()
         if saved_hash != cur_hash:
-            logger.warn(
-                "git hash values are different. {}(saved) != {}(current)".format(
-                    saved_hash[:8], cur_hash[:8]
-                )
-            )
+            logger.warn("git hash values are different. {}(saved) != {}(current)".format(saved_hash[:8], cur_hash[:8]))
     else:
         open(path, "w").write(cur_hash)
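Note on the save_checkpoint rework above: it now receives the checkpoint directory plus a file name (checkpoint_type) and a delete_old flag, and removes the previous matching G_*.pth/D_*.pth only after the new file has been written. Below is a minimal standalone sketch of that delete-after-save pattern; the function name save_latest_only and the example values are illustrative, not code from this repository.

import glob
import os

import torch


def save_latest_only(state, ckpt_dir, filename, pattern="G_*.pth"):
    """Write a new checkpoint, then delete the previous one matching `pattern`.

    Removing the old file only after torch.save has finished means a failed
    save never leaves the directory without a usable checkpoint.
    """
    candidates = glob.glob(os.path.join(ckpt_dir, pattern))
    # order checkpoints by the digits in their names, mirroring latest_checkpoint_path
    candidates.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    previous = candidates[-1] if candidates else None

    new_path = os.path.join(ckpt_dir, filename)
    torch.save(state, new_path)

    if previous is not None and previous != new_path:
        os.remove(previous)
    return new_path


# hypothetical usage:
# save_latest_only({"model": net_g.state_dict(), "iteration": epoch},
#                  "logs/my_exp", "G_12345.pth", pattern="G_*.pth")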

train_nsf_sim_cache_sid_load_pretrain.py

+43 −101
@@ -98,9 +98,7 @@ def run(rank, n_gpus, hps):
         writer = SummaryWriter(log_dir=hps.model_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

-    dist.init_process_group(
-        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
-    )
+    dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
@@ -176,93 +174,55 @@ def run(rank, n_gpus, hps):
         net_d = DDP(net_d)

     try:  # auto-resume if a checkpoint can be loaded
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
-        )  # loading D usually works fine
+        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)  # loading D usually works fine
         if rank == 0:
             logger.info("loaded D")
         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
-        _, _, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
-        )
+        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
+
+        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)

-        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-            optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-        )
-        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-            optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-        )
-
         # epoch_str = 1
         # global_step = 0
     except:  # if nothing can be loaded on the first run, load the pretrained weights
         # traceback.print_exc()
         epoch_str = 0
         global_step = 0
         if hps.pretrainG != "":
-
             if rank == 0:
                 logger.info("loaded pretrained %s" % (hps.pretrainG))
-            print(
-                net_g.module.load_state_dict(
-                    torch.load(hps.pretrainG, map_location="cpu")["model"]
-                )
-            )  ## testing: don't load the optimizer
+            print(net_g.module.load_state_dict(torch.load(hps.pretrainG, map_location="cpu")["model"]))  ## testing: don't load the optimizer
         if hps.pretrainD != "":
-
             if rank == 0:
                 logger.info("loaded pretrained %s" % (hps.pretrainD))
-            print(
-                net_d.module.load_state_dict(
-                    torch.load(hps.pretrainD, map_location="cpu")["model"]
-                )
-            )
-        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-            optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 1
-        )
-        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-            optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 1
-        )
+            print(net_d.module.load_state_dict(torch.load(hps.pretrainD, map_location="cpu")["model"]))
+        scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 1)
+        scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 1)
     global_step = epoch_str * len(train_loader)

-
-
     scaler = GradScaler(enabled=hps.train.fp16_run)

     cache = []
     # for epoch in tqdm.tqdm(range(epoch_str, hps.train.epochs + 1), desc="Training progress", position=1, leave=True):
-    with tqdm.tqdm(total=(hps.total_epoch), desc=f"Training progress, last ckpt saved at epoch: {epoch_str}", position=0, leave=True, initial=0 if global_step == 0 else epoch_str, dynamic_ncols=True) as pbar:
+
+    # disable if rank is not 0 to avoid duplicate main progress bars on each process
+    with tqdm.tqdm(
+        total=(hps.total_epoch),
+        desc=f"Training progress, last ckpt saved at epoch: {epoch_str}",
+        position=0,
+        leave=True,
+        initial=0 if global_step == 0 else epoch_str,
+        dynamic_ncols=True,
+        disable=(rank != 0),
+    ) as pbar:
         for epoch in range(epoch_str + 1, hps.train.epochs + 1):
             if rank == 0:
                 train_and_evaluate(
-                    rank,
-                    epoch,
-                    hps,
-                    [net_g, net_d],
-                    [optim_g, optim_d],
-                    [scheduler_g, scheduler_d],
-                    scaler,
-                    [train_loader, None],
-                    logger,
-                    [writer, writer_eval],
-                    cache,
-                    pbar
+                    rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], logger, [writer, writer_eval], cache, pbar
                 )
             else:
-                train_and_evaluate(
-                    rank,
-                    epoch,
-                    hps,
-                    [net_g, net_d],
-                    [optim_g, optim_d],
-                    [scheduler_g, scheduler_d],
-                    scaler,
-                    [train_loader, None],
-                    None,
-                    None,
-                    cache,
-                    pbar
-                )
+                train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None, cache, pbar)
             pbar.update(1)
             scheduler_g.step()
             scheduler_d.step()
@@ -540,36 +500,25 @@ def train_and_evaluate(
     # /Run steps

     if epoch % hps.save_every_epoch == 0 and rank == 0:
-        if hps.if_latest == 0:
-            utils.save_checkpoint(
-                net_g,
-                optim_g,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
-            )
-            utils.save_checkpoint(
-                net_d,
-                optim_d,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
-            )
-        else:
-            utils.save_checkpoint(
-                net_g,
-                optim_g,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
-            )
-            utils.save_checkpoint(
-                net_d,
-                optim_d,
-                hps.train.learning_rate,
-                epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
-            )
+        utils.save_checkpoint(
+            net_g,
+            optim_g,
+            hps.train.learning_rate,
+            epoch,
+            hps.model_dir,
+            "G_{}.pth".format(global_step),
+            True if hps.if_latest == 1 else False,
+        )
+        utils.save_checkpoint(
+            net_d,
+            optim_d,
+            hps.train.learning_rate,
+            epoch,
+            hps.model_dir,
+            "D_{}.pth".format(global_step),
+            True if hps.if_latest == 1 else False,
+        )
+
         pbar.set_description(f"Training progress, last ckpt saved at epoch {epoch}")
     if rank == 0 and hps.save_every_weights == "1":
         if hasattr(net_g, "module"):
@@ -603,14 +552,7 @@ def train_and_evaluate(
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-        logger.info(
-            "saving final ckpt:%s"
-            % (
-                savee(
-                    ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
-                )
-            )
-        )
+        logger.info("saving final ckpt:%s" % (savee(ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps)))
         sleep(1)
         os._exit(2333333)
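Note on the tqdm change above, which covers the "fixes multi gpu tqdm" part of the commit: every spawned process builds the same main progress bar, so it is now constructed with disable=(rank != 0) and only rank 0 actually renders it. A minimal sketch of that pattern follows; the loop body is a placeholder and training_loop is an illustrative name, not code from this repository.

import tqdm


def training_loop(rank, start_epoch, total_epochs):
    # Only rank 0 draws the bar; other ranks get a no-op bar via disable=True,
    # so pbar.update() / set_description() remain safe to call on every process.
    with tqdm.tqdm(
        total=total_epochs,
        initial=start_epoch,
        desc=f"Training progress, last ckpt saved at epoch: {start_epoch}",
        dynamic_ncols=True,
        disable=(rank != 0),
    ) as pbar:
        for epoch in range(start_epoch + 1, total_epochs + 1):
            # ... one epoch of training/evaluation would run here ...
            pbar.update(1)
            pbar.set_description(f"Training progress, last ckpt saved at epoch {epoch}")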
