-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
109 lines (91 loc) · 2.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import shutil
import keras
import tensorflow as tf
from ultralytics import YOLO
from config import (
BATCH_SIZE,
MAIN_MODEL_PATH_PASCALVOC,
MAIN_MODEL_PATH_ULTRALYTICS,
IMG_RESIZE,
CONFIDENCE,
)
from dataset import Dataset
import models
from utils import (
setup,
)
def train_pascalvoc(dataset: Dataset) -> None:
    """Fine-tune the Pascal VOC YOLOv8 Keras model and save the result.

    Every layer except the last 20 is marked non-trainable, the model is
    compiled with binary cross-entropy classification loss and CIoU box
    loss, fitted for 10 epochs and finally written to
    ``MAIN_MODEL_PATH_PASCALVOC``.

    Args:
        dataset: Source of the training data; its ``load_data_as_keras``
            method must return a pair whose first item is used for fitting
            and whose second item is used as validation data.
    """
    pascalvoc_data = models._get_yolov8_pascalvoc_model_data()
    detector: keras.Model = pascalvoc_data.model

    # Freeze everything but the trailing 20 layers so only those are updated.
    for frozen_layer in detector.layers[:-20]:
        frozen_layer.trainable = False

    adam = keras.optimizers.Adam(0.001)
    detector.compile(
        classification_loss='binary_crossentropy',
        box_loss='ciou',
        optimizer=adam,
        jit_compile=False,
    )

    train_split, validation_split = dataset.load_data_as_keras(
        pascalvoc_data.target_class
    )
    detector.fit(train_split, epochs=10, validation_data=validation_split)

    detector.save(MAIN_MODEL_PATH_PASCALVOC)
    print(f"Trained model saved: {MAIN_MODEL_PATH_PASCALVOC}")
def train_ultralytics(dataset: Dataset) -> None:
    """Fine-tune the Ultralytics YOLOv8s model, validate it and export it to ONNX.

    The function generates the Ultralytics dataset files, trains the model
    for 10 epochs, runs a validation pass on the ``val`` split, exports the
    model in ONNX format and copies the exported file to
    ``MAIN_MODEL_PATH_ULTRALYTICS``.

    Args:
        dataset: Dataset wrapper whose ``generate_ultralytics_files`` method
            is called with the model's target class before training.
    """
    model_data = models._get_yolov8s_ultralytics_model_data()
    model: YOLO = model_data.model

    # Presumably this writes the "dataset.yaml" consumed by model.train below
    # -- TODO confirm against Dataset.generate_ultralytics_files.
    dataset.generate_ultralytics_files(model_data.target_class)

    model.train(
        data="dataset.yaml",
        epochs=10,
        freeze=list(range(5)),  # indices 0..4 are passed as the freeze list
        patience=8,
        batch=BATCH_SIZE,
        imgsz=IMG_RESIZE[1],
        workers=8,
        pretrained=True,
        resume=False,
        single_cls=False,
        box=5,
        cls=0.3,
        dfl=1,
    )

    # The validation result was previously bound to a local that was never
    # read; the call is kept and only the dead binding is dropped.
    model.val(
        imgsz=IMG_RESIZE[1],
        batch=BATCH_SIZE,
        conf=CONFIDENCE,
        iou=0.5,
        save_json=False,
        save_hybrid=False,
        split="val",
    )

    # export() yields the path of the written ONNX file, which is then copied
    # to the project's canonical model location.
    exported_path = model.export(format="onnx")
    shutil.copy(exported_path, MAIN_MODEL_PATH_ULTRALYTICS)
    print(f"Trained model saved: {MAIN_MODEL_PATH_ULTRALYTICS}")
def main():
    """Run the training pipeline from the command line.

    Performs environment setup, prints the available device types, reads
    the ``-model`` option (``pascalvoc`` by default), prepares and annotates
    the dataset, and finally invokes the training function registered for
    the selected model.
    """
    setup()
    device_types = [
        device.device_type for device in tf.config.list_physical_devices()
    ]
    print(f"Devices: {device_types}\n")

    # Maps each accepted "-model" value to the function that trains it.
    trainers = {
        "pascalvoc": train_pascalvoc,
        "ultralytics": train_ultralytics,
    }

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-model",
        type=str,
        required=False,
        choices=list(trainers),
        default="pascalvoc",
        help="Choose model to train",
    )
    selected_model = cli.parse_args().model

    print("Step 1/4: Prepare dataset\n")
    dataset = Dataset.create_dataset()

    print("\nStep 2/4: Build dataset annotations\n")
    annotation_model_data = models.get_pretrained_model_data()
    dataset.create_dataset_annotations(annotation_model_data)

    print("\nStep 3/4: Train + test\n")
    print(f"Model: {selected_model}")
    run_training = trainers[selected_model]
    run_training(dataset)
# Start the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()