-
Notifications
You must be signed in to change notification settings - Fork 17
/
ic_runs.py
57 lines (47 loc) · 2.29 KB
/
ic_runs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
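"""Training runs for a binary Cantharellaceae vs. Amanitaceae image classifier.

Builds a train/test index split over the two families, constructs an
``ICLearner`` around an Inception v3 model, and trains and saves it when the
module is run as a script.
"""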
from numpy.random import RandomState
from numpy.random import shuffle
from pandas import IndexSlice
from torch.utils.data import DataLoader

from fungidata import factory
from ic_learner import ICLearner
label_binary_cf = ('Family == "Cantharellaceae"', 'Family == "Amanitaceae"')
tt = IndexSlice[:, :, :, :, :, ['Cantharellaceae', 'Amanitaceae'], :, :, :]
v1 = list(range(759))
v2 = list(range(759, 2429))
shuffle(v1)
shuffle(v2)
vv = v1[:500] + v2[:1000]
vv_test = v1[500:700] + v2[1000:1300]
# Image collection root and its table-of-contents CSV. Both paths are relative
# to the current working directory. They are defined once here so the test
# dataset and the learner cannot drift apart.
FUNGI_ROOT = '../../Desktop/Fungi'
FUNGI_TOC_CSV = FUNGI_ROOT + '/toc_full.csv'

# Held-out test dataset built from the vv_test positions (min_dim=299).
dataset2 = factory.create('full basic labelled', csv_file=FUNGI_TOC_CSV,
                          img_root_dir=FUNGI_ROOT, label_keys=label_binary_cf,
                          selector=tt, iselector=vv_test, min_dim=299)
# Test batches are served in a fixed order (shuffle=False).
dataloader_test = DataLoader(dataset2, batch_size=16, shuffle=False)

# Learner trained on the vv positions and evaluated on dataloader_test.
# NOTE(review): training passes min_dim=300 while the test dataset above uses
# min_dim=299 (the disabled augmented run also uses 299). Confirm the mismatch
# is intentional.
learner_1 = ICLearner(run_label='simple classification test run',
                      raw_csv_toc=FUNGI_TOC_CSV, raw_csv_root=FUNGI_ROOT,
                      save_tmp_name='model_training_ic',
                      ic_model='inception_v3', min_dim=300,
                      loader_batch_size=16,
                      selector=tt, iselector=vv,
                      label_keys=label_binary_cf,
                      lr_init=0.01, scheduler_gamma=0.25, scheduler_step_size=5,
                      random_seed=79, test_dataloader=dataloader_test,
                      test_datasetsize=len(dataset2))
#learner_2 = ICLearner(run_label='simple classification test run with data augmentation',
# raw_csv_toc='../../Desktop/Fungi/toc_full.csv', raw_csv_root='../../Desktop/Fungi',
# loader_batch_size=16,
# selector=tt, iselector=vv,
# label_keys=label_binary_cf,
# dataset_type='full aug labelled',
# ic_model='inception_v3',min_dim=299,
# aug_multiplicity=1, aug_label='random_resized_crop',
# lr_init=0.01, scheduler_step_size=7, scheduler_gamma=0.1,
# random_seed=79, test_dataloader=dataloader_test)
def train_simple_ic(n_epochs=20, save_name='ic_run_1'):
    """Train the module-level ``learner_1`` and save the resulting model.

    Calling with no arguments behaves exactly as before: ``train(20)``
    followed by ``save_model('ic_run_1')``.

    Args:
        n_epochs: Value passed to ``learner_1.train``; presumably the number
            of training epochs (confirm against ``ICLearner.train``).
        save_name: Name passed to ``learner_1.save_model``.
    """
    learner_1.train(n_epochs)
    learner_1.save_model(save_name)
#def train_aug_ic():
# learner_2.train(16)
# learner_2.save_model('ic_bigrun_2')
if __name__ == '__main__':
    # Script entry point: only the plain (non-augmented) training run is
    # active. The augmented variant stays disabled; re-enable it together
    # with learner_2 and train_aug_ic.
    train_simple_ic()
    # train_aug_ic()