Skip to content

Commit 6c0ecaa

Browse files
authored
v0.3.2
1 parent dc0dc84 commit 6c0ecaa

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

train/main.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
# with successful or good enough to try runs copied from ../train_log_search_model
1010
# (Placing "skip" file into the directory would skip it).
1111
#
12+
# You can also directly run training.py with --save filepath provided to
13+
# save a model without history.
14+
#
1215
# Tensorboard info: Don't start from CWD or parent, use for example ~
1316
# > conda activate nn
1417
# > tensorboard --logdir <abs-path-to-log-dir>

train/training.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,22 @@
3939
LAT_D = 8
4040
SUB_D = 12
4141

42-
MAXEPOCHS = 3001 # 350 # 1500
42+
MAXEPOCHS = 3000 # 350 # 1500 # 3000 # 3001
4343
OFFSET_EP = 0
4444
REAL_MAX_EP = True
4545

4646
if X_EXT_D != len(EXTRA_QUEST) or X_D != MAIN_QUEST_N: raise RuntimeError('Inconsistent constants.')
4747

4848
version: Opt[str]
49-
name, version = DEFAULTNAME, None
49+
save: Opt[str]
50+
name, version, save = DEFAULTNAME, None, None
5051
if __name__ == '__main__':
5152
parser = argparse.ArgumentParser()
5253
parser.add_argument("--name", default=str(name))
5354
parser.add_argument("--ver", default=version)
55+
parser.add_argument("--save", default=save)
5456
args = parser.parse_args()
55-
name, version = args.name, args.ver
57+
name, version, save = args.name, args.ver, args.save
5658

5759

5860
parent_dir = path.dirname(path.dirname(path.abspath(__file__)))
@@ -413,14 +415,14 @@ def configure_optimizers(self):
413415

414416

415417
autoencoder = LightVAE(offset_step=183 * OFFSET_EP)
416-
trainer_ = pl.Trainer(max_epochs=maxepochs, logger=logger, check_val_every_n_epoch=5, callbacks=[
417-
plot_callback, git_dir_sha,
418-
ModelCheckpoint(every_n_epochs=10, save_top_k=-1),
419-
])
418+
trainer_ = pl.Trainer(max_epochs=maxepochs, logger=logger, check_val_every_n_epoch=5 if save is None else 1,
419+
callbacks=[plot_callback, git_dir_sha,
420+
ModelCheckpoint(every_n_epochs=10 if save is None else 1, save_top_k=-1)
421+
])
420422

421423

422424
if __name__ == '__main__':
423-
trainer_.fit(autoencoder, train_loaders, test_loader, ckpt_path=checkpoint_path)
425+
trainer_.fit(autoencoder, train_loaders, test_loader, ckpt_path=checkpoint_path if save is None else save)
424426
if not autoencoder.successful_run:
425427
raise RuntimeError('Unsuccessful. Skipping this train run.')
426428
raise RuntimeError('Force skip all train runs.')

0 commit comments

Comments
 (0)