|
39 | 39 | LAT_D = 8
|
40 | 40 | SUB_D = 12
|
41 | 41 |
|
42 |
| -MAXEPOCHS = 3001 # 350 # 1500 |
| 42 | +MAXEPOCHS = 3000 # 350 # 1500 # 3000 # 3001 |
43 | 43 | OFFSET_EP = 0
|
44 | 44 | REAL_MAX_EP = True
|
45 | 45 |
|
46 | 46 | if X_EXT_D != len(EXTRA_QUEST) or X_D != MAIN_QUEST_N: raise RuntimeError('Inconsistent constants.')
|
47 | 47 |
|
48 | 48 | version: Opt[str]
|
49 |
| -name, version = DEFAULTNAME, None |
| 49 | +save: Opt[str] |
| 50 | +name, version, save = DEFAULTNAME, None, None |
50 | 51 | if __name__ == '__main__':
|
51 | 52 | parser = argparse.ArgumentParser()
|
52 | 53 | parser.add_argument("--name", default=str(name))
|
53 | 54 | parser.add_argument("--ver", default=version)
|
| 55 | + parser.add_argument("--save", default=save) |
54 | 56 | args = parser.parse_args()
|
55 |
| - name, version = args.name, args.ver |
| 57 | + name, version, save = args.name, args.ver, args.save |
56 | 58 |
|
57 | 59 |
|
58 | 60 | parent_dir = path.dirname(path.dirname(path.abspath(__file__)))
|
@@ -413,14 +415,14 @@ def configure_optimizers(self):
|
autoencoder = LightVAE(offset_step=183 * OFFSET_EP)

# When resuming from an explicit checkpoint (`--save` given), validate and
# checkpoint every epoch; otherwise use the sparser default cadence
# (validate every 5 epochs, checkpoint every 10).  save_top_k=-1 keeps
# every checkpoint rather than only the best one.
trainer_ = pl.Trainer(
    max_epochs=maxepochs,
    logger=logger,
    check_val_every_n_epoch=5 if save is None else 1,
    callbacks=[
        plot_callback,
        git_dir_sha,
        ModelCheckpoint(every_n_epochs=10 if save is None else 1, save_top_k=-1),
    ],
)


if __name__ == '__main__':
    # Resume from the `--save` checkpoint when provided, else from the
    # default checkpoint path computed earlier in this file.
    trainer_.fit(autoencoder, train_loaders, test_loader,
                 ckpt_path=checkpoint_path if save is None else save)
    if not autoencoder.successful_run:
        raise RuntimeError('Unsuccessful. Skipping this train run.')
    # Deliberate hard stop: this script is meant to be run per-configuration,
    # and raising here prevents any further train runs in the same process.
    raise RuntimeError('Force skip all train runs.')
|
0 commit comments