-
Notifications
You must be signed in to change notification settings - Fork 17
/
ae_runs.py
64 lines (53 loc) · 2.18 KB
/
ae_runs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
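'''Driver script for AELearner runs on the Fungi dataset: train from scratch,
resume training from a saved checkpoint, or evaluate a saved model.
'''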
from numpy.random import RandomState
from numpy.random import shuffle
from pandas import IndexSlice

from ae_learner import AELearner
from img_transforms import UnNormalizeTransform
# Earlier run configurations, kept for reference. train_from_scratch,
# train_from_existing and eval_from_existing (below) refer to learner_1, so its
# construction has to be uncommented before those functions can use it.
#learner_1 = AELearner(run_label='simple test run',
# raw_csv_toc='../../Desktop/Fungi/toc_full.csv', raw_csv_root='../../Desktop/Fungi',
# dataset_type='grid basic',
# loader_batch_size=64, iselector=[0,1,2,3,4,5],
# lr_init=0.01, freeze_encoder=True,
# random_seed=79)
#
#tt = IndexSlice[:, :, :, :, :, ['Cantharellaceae'], :, :, :]
#learner_2 = AELearner(raw_csv_toc='../../Desktop/Fungi/toc_full.csv', raw_csv_root='../../Desktop/Fungi',
# loader_batch_size=128, selector=tt,
# iselector=list(range(120)),
# lr_init=0.03, scheduler_step_size=12,
# freeze_encoder=False,
# random_seed=79)

# Seed shared by the subset shuffle below and by learner_3 itself.
RANDOM_SEED = 79

# Multi-index slice that keeps rows whose sixth index level is one of these two
# values (presumably the taxonomic family level - confirm against toc_full.csv).
tt = IndexSlice[:, :, :, :, :, ['Cantharellaceae', 'Amanitaceae'], :, :, :]

# NOTE(review): these look like the row layout of the selected frame (first
# family = rows 0..758, second family = rows 759..2428) - confirm before
# changing the selector.
SPLIT_AT = 759
N_ROWS = 2429

# Shuffle with a dedicated, seeded generator. Nothing in this file seeds numpy
# before this point, so the unseeded global shuffle used previously picked a
# different training subset on every run despite random_seed=79. A private
# RandomState also leaves numpy's global RNG state untouched.
_subset_rng = RandomState(RANDOM_SEED)
v1 = list(range(SPLIT_AT))
v2 = list(range(SPLIT_AT, N_ROWS))
_subset_rng.shuffle(v1)
_subset_rng.shuffle(v2)
# 100 positional indices from the first range plus 300 from the second.
vv = v1[:100] + v2[:300]

learner_3 = AELearner(raw_csv_toc='../../Desktop/Fungi/toc_full.csv', raw_csv_root='../../Desktop/Fungi',
                      dataset_type='grid basic',
                      loader_batch_size=64, selector=tt,
                      iselector=vv,
                      lr_init=0.01, scheduler_step_size=8, scheduler_gamma=0.1,
                      freeze_encoder=False,
                      random_seed=RANDOM_SEED)
def train_from_scratch(learner=None, n_epochs=6, save_name='ae_learner_run_1'):
    """Train ``learner`` from its freshly constructed state and save a checkpoint.

    Args:
        learner: Object exposing ``train(n)`` and ``save_model(name)``. Defaults
            to the module-level ``learner_1``.
        n_epochs: Passed straight to ``learner.train`` (presumably an epoch
            count). Defaults to 6, as in the original run.
        save_name: Checkpoint name passed to ``learner.save_model``.

    Raises:
        NameError: if ``learner`` is omitted and the module-level ``learner_1``
            is not defined (its construction is commented out at module level).
    """
    if learner is None:
        try:
            learner = learner_1
        except NameError:
            raise NameError(
                "learner_1 is not defined (its AELearner construction is commented "
                "out at module level); uncomment it or pass learner= explicitly"
            ) from None
    learner.train(n_epochs)
    learner.save_model(save_name)
def train_from_existing(learner=None, n_epochs=6,
                        load_name='ae_learner_run_1', save_name='ae_learner_run_2'):
    """Resume training ``learner`` from a saved checkpoint and save the result.

    Args:
        learner: Object exposing ``load_model(name)``, ``train(n)`` and
            ``save_model(name)``. Defaults to the module-level ``learner_1``.
        n_epochs: Passed straight to ``learner.train`` (presumably an epoch
            count). Defaults to 6, as in the original run.
        load_name: Checkpoint name to load before training.
        save_name: Checkpoint name to save after training.

    Raises:
        NameError: if ``learner`` is omitted and the module-level ``learner_1``
            is not defined (its construction is commented out at module level).
    """
    if learner is None:
        try:
            learner = learner_1
        except NameError:
            raise NameError(
                "learner_1 is not defined (its AELearner construction is commented "
                "out at module level); uncomment it or pass learner= explicitly"
            ) from None
    learner.load_model(load_name)
    learner.train(n_epochs)
    learner.save_model(save_name)
def eval_from_existing(learner=None, load_name='ae_learner_run_2', untransform=None):
    """Load a saved checkpoint into ``learner`` and print the shape of each evaluation output.

    Args:
        learner: Object exposing ``load_model(name)`` and
            ``eval_model(untransform=...)``. Defaults to the module-level
            ``learner_1``.
        load_name: Checkpoint name passed to ``learner.load_model``.
        untransform: Passed to ``eval_model``. Defaults to a fresh
            ``UnNormalizeTransform()``, built after the checkpoint is loaded
            (same order as the original run).

    Raises:
        NameError: if ``learner`` is omitted and the module-level ``learner_1``
            is not defined (its construction is commented out at module level).
    """
    if learner is None:
        try:
            learner = learner_1
        except NameError:
            raise NameError(
                "learner_1 is not defined (its AELearner construction is commented "
                "out at module level); uncomment it or pass learner= explicitly"
            ) from None
    learner.load_model(load_name)
    if untransform is None:
        untransform = UnNormalizeTransform()
    for out in learner.eval_model(untransform=untransform):
        print(out.shape)
def train_bigger(learner=None, n_epochs=16,
                 load_name='ae_learner_2_bigger', save_name='ae_learner_2_bigger_2'):
    """Continue training ``learner`` from a saved checkpoint and save under a new name.

    Args:
        learner: Object exposing ``load_model(name)``, ``train(n)`` and
            ``save_model(name)``. Defaults to the module-level ``learner_3``.
        n_epochs: Passed straight to ``learner.train`` (presumably an epoch
            count). Defaults to 16, as in the original run.
        load_name: Checkpoint name to load before training.
        save_name: Checkpoint name to save after training.
    """
    if learner is None:
        learner = learner_3
    learner.load_model(load_name)
    learner.train(n_epochs)
    learner.save_model(save_name)
if __name__ == '__main__':
    # Choose which run to execute. The other entry points are
    # train_from_scratch, train_from_existing and eval_from_existing; they
    # use learner_1, whose construction is commented out at module level.
    selected_run = train_bigger
    selected_run()