-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
110 lines (85 loc) · 3.99 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import division
import os
import argparse
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.optim as optim
from util.parser import load_classes, parse_data_config, parse_model_configuration
from util.model import Darknet
from util.utils import weights_init_normal
from util.datasets import ListDataset
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a well-known trap: any non-empty string is
    truthy, so ``--use_cuda False`` would still evaluate to True.  This parser
    treats "1", "true", "yes", "y" (case-insensitive) as True, everything
    else as False.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ("1", "true", "yes", "y")


# Command-line configuration for the training run.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=30, help="number of epochs")
parser.add_argument("--image_folder", type=str, default="data/samples", help="path to dataset")
parser.add_argument("--batch_size", type=int, default=16, help="size of each image batch")
parser.add_argument("--model_config_path", type=str, default="config/yolov3.cfg", help="path to model config file")
parser.add_argument("--data_config_path", type=str, default="config/coco.data", help="path to data config file")
parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
parser.add_argument("--nms_thres", type=float, default=0.4, help="iou threshold for non-maximum suppression")
parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="directory where model checkpoints are saved")
# BUG FIX: was type=bool, which returns True for ANY non-empty string.
parser.add_argument("--use_cuda", type=_str2bool, default=True, help="whether to use cuda if available")
args = parser.parse_args()
# Use CUDA only when both the hardware supports it and the user asked for it.
cuda = torch.cuda.is_available() and args.use_cuda

# Prepare output directories.
os.makedirs("output", exist_ok=True)
# BUG FIX: was hard-coded "checkpoints", silently ignoring --checkpoint_dir
# (which the checkpoint-saving code at the bottom of the file does use).
os.makedirs(args.checkpoint_dir, exist_ok=True)

# Class names, one label per line (e.g. data/coco.names).
classes = load_classes(args.class_path)

# Get data configuration; "train" points at the list file of training images.
data_config = parse_data_config(args.data_config_path)
train_path = data_config["train"]

# Get hyper parameters from the first ([net]) section of the darknet cfg.
# NOTE(review): these values are parsed but never passed to the optimizer
# later in this script — confirm whether that is intentional.
hyperparams = parse_model_configuration(args.model_config_path)[0]
learning_rate = float(hyperparams["learning_rate"])
momentum = float(hyperparams["momentum"])
decay = float(hyperparams["decay"])
burn_in = int(hyperparams["burn_in"])
# Initiate model: build the network graph from the darknet cfg, then either
# load pretrained weights or initialise randomly.
model = Darknet(args.model_config_path)
if args.weights_path:
    model.load_weights(args.weights_path)
else:
    model.apply(weights_init_normal)

if cuda:
    model = model.cuda()

model.train()  # training mode (affects dropout / batch-norm behaviour)

# Get dataloader over the training list file.
# NOTE(review): shuffle=False preserves the original behaviour; training
# normally wants shuffle=True — confirm before changing.
dataloader = DataLoader(
    ListDataset(train_path),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.n_cpu,
)

# Tensor factory matching the chosen device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Optimise only trainable parameters (frozen layers are skipped).
# NOTE(review): the cfg's learning_rate/momentum/decay are NOT passed here,
# so Adam runs with its defaults — verify this is intended.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
# Main training loop: one optimiser step per batch.  The model's forward
# pass returns the total YOLO loss when targets are supplied, and stashes
# the per-component losses on `model.losses`.
for epoch in range(args.epochs):
    for batch_i, (_, imgs, targets) in enumerate(dataloader):
        imgs = Variable(imgs.type(Tensor))
        targets = Variable(targets.type(Tensor), requires_grad=False)

        optimizer.zero_grad()
        # BUG FIX: removed stray debug statement print('aa') that was here.
        loss = model(imgs, targets)
        loss.backward()
        optimizer.step()

        # Per-batch progress report.
        print(
            "[Epoch %d/%d, Batch %d/%d] [Losses: x %f, y %f, w %f, h %f, conf %f, cls %f, total %f, recall: %.5f, precision: %.5f]"
            % (
                epoch,
                args.epochs,
                batch_i,
                len(dataloader),
                model.losses["x"],
                model.losses["y"],
                model.losses["w"],
                model.losses["h"],
                model.losses["conf"],
                model.losses["cls"],
                loss.item(),
                model.losses["recall"],
                model.losses["precision"],
            )
        )

        # Count of images processed so far (written into darknet-format
        # weight files by save_weights).
        model.seen += imgs.size(0)

    # Periodic checkpoint; epoch 0 is included by the modulo test.
    if epoch % args.checkpoint_interval == 0:
        model.save_weights("%s/%d.weights" % (args.checkpoint_dir, epoch))