-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
56 lines (47 loc) · 1.36 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import keras
import keras_cv
from ultralytics import YOLO
from config import (
BOUNDING_BOX_FORMAT,
IMG_RESIZE,
MAIN_MODEL_PATH_ULTRALYTICS,
)
from utils import (
ModelData
)
# Shared preprocessing pipeline used by every ModelData built in this module.
# It accepts images of any height/width with 3 channels and resizes them to
# IMG_RESIZE. pad_to_aspect_ratio=True pads instead of stretching, so the
# aspect ratio is kept. bounding_box_format tells the layer how to adjust
# boxes passed alongside the images.
preprocess_model = keras.Sequential(
    [
        keras.layers.Input(shape=(None, None, 3)),
        keras_cv.layers.Resizing(
            *IMG_RESIZE,
            pad_to_aspect_ratio=True,
            bounding_box_format=BOUNDING_BOX_FORMAT,
        ),
    ]
)
def get_pretrained_model_data() -> ModelData:
    """Return the off-the-shelf detector bundle.

    This is the KerasCV YOLOv8-M detector loaded from the
    ``yolo_v8_m_pascalvoc`` preset, paired with ``preprocess_model``.
    """
    pretrained = _get_yolov8_pascalvoc_model_data()
    return pretrained
def get_main_model_data() -> ModelData:
    """Return the project's main detector bundle.

    This is the Ultralytics YOLO model loaded from the checkpoint at
    ``MAIN_MODEL_PATH_ULTRALYTICS``, paired with ``preprocess_model``.
    """
    main_bundle = _get_yolov8s_ultralytics_model_data_trained()
    return main_bundle
def _get_yolov8_pascalvoc_model_data() -> ModelData:
    """Build a ModelData around KerasCV's Pascal VOC YOLOv8-M detector.

    The detector is created from the ``yolo_v8_m_pascalvoc`` preset with 20
    output classes. It emits boxes in ``BOUNDING_BOX_FORMAT`` and is paired
    with the module-level ``preprocess_model``.
    """
    detector = keras_cv.models.YOLOV8Detector.from_preset(
        "yolo_v8_m_pascalvoc",
        bounding_box_format=BOUNDING_BOX_FORMAT,
        num_classes=20,
    )
    # NOTE(review): 14 is presumably a target class id ("person" is index 14
    # in the usual 20-class Pascal VOC ordering). Confirm against ModelData.
    class_index = 14
    return ModelData(detector, preprocess_model, class_index)
def _get_yolov8s_ultralytics_model_data() -> ModelData:
    """Build a ModelData around the stock Ultralytics YOLOv8s detector.

    Loads the ``yolov8s.pt`` checkpoint as a detection model
    (``task="detect"``) and pairs it with the module-level
    ``preprocess_model``.
    """
    detector = YOLO("yolov8s.pt", task="detect")
    # NOTE(review): 0 is presumably a target class id ("person" in the COCO
    # ordering the stock weights are usually trained on). Confirm against
    # ModelData.
    class_index = 0
    return ModelData(detector, preprocess_model, class_index)
def _get_yolov8s_ultralytics_model_data_trained() -> ModelData:
    """Build a ModelData around the project's own Ultralytics checkpoint.

    Loads the weights at ``MAIN_MODEL_PATH_ULTRALYTICS`` as a detection model
    (``task="detect"``) and pairs them with the module-level
    ``preprocess_model``.
    """
    checkpoint_path = MAIN_MODEL_PATH_ULTRALYTICS
    detector = YOLO(checkpoint_path, task="detect")
    # NOTE(review): 0 is presumably the target class id used during
    # training. Confirm against ModelData and the training dataset config.
    class_index = 0
    return ModelData(detector, preprocess_model, class_index)
def main() -> None:
    """Load the stock Ultralytics YOLOv8s model and print a summary of it.

    ``ModelData.model`` may hold either a Keras model or an Ultralytics
    ``YOLO`` object, and the two expose their summary differently. Keras
    models provide ``summary()``. Ultralytics models have no ``summary``
    attribute and print layer/parameter counts through ``info()`` instead.
    Calling ``summary()`` unconditionally on the YOLO object loaded here
    raised ``AttributeError``.
    """
    model_data = _get_yolov8s_ultralytics_model_data()
    model = model_data.model
    if hasattr(model, "summary"):
        # Keras-style model.
        model.summary()
    else:
        # Ultralytics-style model.
        model.info()


if __name__ == "__main__":
    main()