-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
61 lines (52 loc) · 2.2 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
from src.data_helper import DataProcessor
from configs import configs, img_configs
from src.trainer import Trainer
# from models.ProNet import ProNet
from models.SimpleProNet import ProNet
from models.ProNet2 import ProNet2
from models.VGG import VGG
from models.dla import DLA
from models.densenet import densenet_cifar
import cv2
from utils import *
def get_dataset(file_dir, file_name, iter, saved=False):
    """Load one data batch and split it into a train set and a test set.

    Two modes, selected by ``configs['preprocess']``:

    * truthy  -- read the raw batch ``file_dir``/``file_name``, run it through
      ``DataProcessor.generate_data`` with the block settings from
      ``img_configs`` (2 jobs), then split with ratio 0.95;
    * falsy   -- read the already-generated batch
      ``input/data/64x64/train_batchs/dataset_{iter}.bin`` and split with
      ratio 0.97.

    The ratio is passed as the second argument of
    ``DataProcessor.split_dataset`` (presumably the train fraction -- confirm
    against ``src.data_helper``).

    Args:
        file_dir: Directory of the raw batch (used only in preprocess mode).
        file_name: File name of the raw batch (used only in preprocess mode).
        iter: Batch index used to build the cached file name (used only when
            not preprocessing). It shadows the builtin ``iter``; the name is
            kept so existing keyword callers keep working.
        saved: Forwarded as the ``saved`` flag of
            ``DataProcessor.split_dataset``. Earlier versions ignored this
            argument and always passed ``False``; the default is still
            ``False``.

    Returns:
        ``(trainset, testset)``, as produced by ``DataProcessor.split_dataset``.
    """
    if configs['preprocess']:
        dataset = DataProcessor.load_data_from_binary_file(file_dir, file_name)
        # Rebind ``dataset`` (rather than using a new name) so the raw batch is
        # no longer referenced once the generated data exists.
        dataset = DataProcessor.generate_data(dataset,
                                              img_configs['block-dim'],
                                              img_configs['block-size'],
                                              n_jobs=2)
        # The cached files read by the other branch were apparently produced by:
        # DataProcessor.save_data_to_binary_file(dataset, "input/data/64x64/train_batchs/dataset_{}.bin".format(iter))
        split_ratio = 0.95
    else:
        dataset = DataProcessor.load_data_from_binary_file(
            "input/data/64x64/train_batchs/", "dataset_{}.bin".format(iter))
        split_ratio = 0.97
    # Unpack to enforce exactly two parts, and always return a tuple.
    trainset, testset = DataProcessor.split_dataset(dataset, split_ratio, saved=saved)
    return trainset, testset
def main():
    """Train a VGG7 model batch-by-batch over the on-disk image datasets.

    Forces preprocessing on and sets the batch count to 200 in the shared
    ``configs`` dict. It then builds the ``Trainer`` and restores weights via
    ``trainer.model.load(0, 1580)``. For every batch index it:

    1. loads and splits the data with ``get_dataset``;
    2. prints both split sizes;
    3. trains once;
    4. detaches the data from the trainer before the next batch.
    """
    configs['preprocess'] = True
    configs['num-dataset'] = 200
    raw_dir = "input/data/64x64/images/"

    network = VGG('VGG7')
    trainer = Trainer(model=network,
                      lr=0.003,
                      loss='bce',
                      optimizer='adas',
                      batch_size=256,
                      n_repeats=1,
                      save_every=500)
    # NOTE(review): (0, 1580) looks like a checkpoint identifier -- confirm
    # against the model's ``load`` signature.
    trainer.model.load(0, 1580)

    for batch_idx in range(configs['num-dataset']):
        train_split, test_split = get_dataset(
            raw_dir, "image_data_batch_{}.bin".format(batch_idx), batch_idx, saved=False)
        print("Train set size: ", len(train_split['data']))
        print("Test set size: ", len(test_split['data']))

        # Assigned left to right: train_loader first, then test_loader.
        trainer.train_loader, trainer.test_loader = train_split, test_split
        trainer.train()

        # Drop every reference to this batch so it can be garbage-collected
        # before the next one is loaded.
        trainer.train_loader = trainer.test_loader = None
        train_split = test_split = None
# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()