@@ -31,10 +31,7 @@ def go(model, bkey):
             try:
                 new_state_dict[k] = saved_state_dict[k]
                 if saved_state_dict[k].shape != state_dict[k].shape:
-                    print(
-                        "shape-%s-mismatch|need-%s|get-%s"
-                        % (k, state_dict[k].shape, saved_state_dict[k].shape)
-                    )  #
+                    print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape))  #
                     raise KeyError
             except:
                 # logger.info(traceback.format_exc())
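Note: the reflowed print keeps the same message format; the raise KeyError right after it is what routes a mismatched tensor into the except: branch below. A hypothetical output line for illustration only (the key name and shapes are made up, not from this commit):

    shape-emb_g.weight-mismatch|need-torch.Size([256, 192])|get-torch.Size([109, 192])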
@@ -52,9 +49,7 @@ def go(model, bkey):
 
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ### if it can't be loaded (e.g. it is empty), re-initialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
+    if optimizer is not None and load_opt == 1:  ### if it can't be loaded (e.g. it is empty), re-initialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
     # except:
@@ -106,10 +101,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
         try:
             new_state_dict[k] = saved_state_dict[k]
             if saved_state_dict[k].shape != state_dict[k].shape:
-                print(
-                    "shape-%s-mismatch|need-%s|get-%s"
-                    % (k, state_dict[k].shape, saved_state_dict[k].shape)
-                )  #
+                print("shape-%s-mismatch|need-%s|get-%s" % (k, state_dict[k].shape, saved_state_dict[k].shape))  #
                 raise KeyError
         except:
             # logger.info(traceback.format_exc())
@@ -123,9 +115,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
 
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ### if it can't be loaded (e.g. it is empty), re-initialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
+    if optimizer is not None and load_opt == 1:  ### if it can't be loaded (e.g. it is empty), re-initialize; this may also affect the lr schedule update, so catch it at the outermost level of the train file
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
     # except:
@@ -134,33 +124,39 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
     return model, optimizer, learning_rate, iteration
 
 
-def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path, checkpoint_type, delete_old=False):
+    # logger.info(
+    #     "Saving model and optimizer state at epoch {} to {}".format(
+    #         iteration, checkpoint_path
+    #     )
+    # )
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
+    if delete_old:
+        latest_checkpoint = latest_checkpoint_path(checkpoint_path, regex=("G_*.pth" if checkpoint_type.startswith("G") else "D_*.pth"))
+
     torch.save(
         {
             "model": state_dict,
             "iteration": iteration,
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
-        checkpoint_path,
+        os.path.join(checkpoint_path, checkpoint_type),
     )
+    # delete after saving new checkpoint to avoid loss if save fails
+    if delete_old and latest_checkpoint is not None:
+        os.remove(latest_checkpoint)
 
 
 def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
+    # logger.info(
+    #     "Saving model and optimizer state at epoch {} to {}".format(
+    #         iteration, checkpoint_path
+    #     )
+
     if hasattr(combd, "module"):
         state_dict_combd = combd.module.state_dict()
     else:
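Note: with this change, checkpoint_path is treated as a directory and checkpoint_type as the file name joined onto it, and delete_old removes the previous matching checkpoint only after the new file has been written. A minimal call-site sketch for the new save_checkpoint signature; the names net_g, optim_g, hps, epoch and global_step are assumptions for illustration, not part of this diff:

    # hypothetical usage of the new signature; every name and value below is illustrative
    save_checkpoint(
        net_g,                            # model (plain module or one wrapped with .module)
        optim_g,                          # its optimizer
        hps.train.learning_rate,
        epoch,
        hps.model_dir,                    # directory holding G_*.pth / D_*.pth files
        "G_{}.pth".format(global_step),   # joined via os.path.join(checkpoint_path, checkpoint_type)
        delete_old=True,                  # previous G_*.pth is removed only after torch.save completes
    )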
@@ -204,7 +200,7 @@ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
     f_list = glob.glob(os.path.join(dir_path, regex))
     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
     x = f_list[-1]
-    print(x)
+    # print(x)
     return x
 
 
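Note: latest_checkpoint_path, which the new delete_old path in save_checkpoint relies on, sorts the glob matches by the integer formed from every digit in each path and returns the last one. A small worked example with made-up file names:

    # hypothetical directory "logs/exp1" containing G_100.pth and G_2333.pth
    # sort key = int of all digits in the path:
    #   "logs/exp1/G_100.pth"  -> int("1100")  = 1100
    #   "logs/exp1/G_2333.pth" -> int("12333") = 12333
    latest = latest_checkpoint_path("logs/exp1", "G_*.pth")  # -> "logs/exp1/G_2333.pth"
    # as written above, an empty match list would make f_list[-1] raise IndexError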
@@ -247,9 +243,7 @@ def plot_alignment_to_numpy(alignment, info=None):
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(6, 4))
-    im = ax.imshow(
-        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-    )
+    im = ax.imshow(alignment.transpose(), aspect="auto", origin="lower", interpolation="none")
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
     if info is not None:
@@ -302,35 +296,21 @@ def get_hparams(init=True):
         required=True,
         help="checkpoint save frequency (epoch)",
     )
-    parser.add_argument(
-        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
-    )
-    parser.add_argument(
-        "-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path"
-    )
-    parser.add_argument(
-        "-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path"
-    )
+    parser.add_argument("-te", "--total_epoch", type=int, required=True, help="total_epoch")
+    parser.add_argument("-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path")
+    parser.add_argument("-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path")
     parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
-    parser.add_argument(
-        "-bs", "--batch_size", type=int, required=True, help="batch size"
-    )
-    parser.add_argument(
-        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
-    )  # -m
-    parser.add_argument(
-        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
-    )
+    parser.add_argument("-bs", "--batch_size", type=int, required=True, help="batch size")
+    parser.add_argument("-e", "--experiment_dir", type=str, required=True, help="experiment dir")  # -m
+    parser.add_argument("-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k")
     parser.add_argument(
         "-sw",
         "--save_every_weights",
         type=str,
         default="0",
         help="save the extracted model in weights directory when saving checkpoints",
     )
-    parser.add_argument(
-        "-v", "--version", type=str, required=True, help="model version"
-    )
+    parser.add_argument("-v", "--version", type=str, required=True, help="model version")
     parser.add_argument(
         "-f0",
         "--if_f0",
@@ -414,11 +394,7 @@ def get_hparams_from_file(config_path):
 def check_git_hash(model_dir):
     source_dir = os.path.dirname(os.path.realpath(__file__))
     if not os.path.exists(os.path.join(source_dir, ".git")):
-        logger.warn(
-            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                source_dir
-            )
-        )
+        logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(source_dir))
         return
 
     cur_hash = subprocess.getoutput("git rev-parse HEAD")
@@ -427,11 +403,7 @@ def check_git_hash(model_dir):
     if os.path.exists(path):
         saved_hash = open(path).read()
         if saved_hash != cur_hash:
-            logger.warn(
-                "git hash values are different. {}(saved) != {}(current)".format(
-                    saved_hash[:8], cur_hash[:8]
-                )
-            )
+            logger.warn("git hash values are different. {}(saved) != {}(current)".format(saved_hash[:8], cur_hash[:8]))
     else:
         open(path, "w").write(cur_hash)
 